chore(langchain_v1,anthropic): migrate anthropic middleware to langchain_anthropic (#33463)

Migrate prompt caching implementation into langchain_anthropic.middleware
This commit is contained in:
Eugene Yurtsev
2025-10-13 15:12:54 -04:00
committed by GitHub
parent 961f965f0c
commit 1cf851e054
9 changed files with 1292 additions and 673 deletions

View File

@@ -12,7 +12,6 @@ from .model_call_limit import ModelCallLimitMiddleware
from .model_fallback import ModelFallbackMiddleware
from .pii import PIIDetectionError, PIIMiddleware
from .planning import PlanningMiddleware
from .prompt_caching import AnthropicPromptCachingMiddleware
from .summarization import SummarizationMiddleware
from .tool_call_limit import ToolCallLimitMiddleware
from .tool_emulator import LLMToolEmulator
@@ -33,8 +32,6 @@ from .types import (
__all__ = [
"AgentMiddleware",
"AgentState",
# should move to langchain-anthropic if we decide to keep it
"AnthropicPromptCachingMiddleware",
"ClearToolUsesEdit",
"ContextEditingMiddleware",
"HumanInTheLoopMiddleware",

View File

@@ -1,89 +0,0 @@
"""Anthropic prompt caching middleware."""
from collections.abc import Callable
from typing import Literal
from warnings import warn
from langchain.agents.middleware.types import (
AgentMiddleware,
ModelCallResult,
ModelRequest,
ModelResponse,
)
class AnthropicPromptCachingMiddleware(AgentMiddleware):
    """Agent middleware that requests Anthropic prompt caching.

    Adds a ``cache_control`` entry to the model settings of a request so that
    conversation prefixes can be cached by Anthropic models.

    Learn more about Anthropic prompt caching
    [here](https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching).
    """

    def __init__(
        self,
        type: Literal["ephemeral"] = "ephemeral",
        ttl: Literal["5m", "1h"] = "5m",
        min_messages_to_cache: int = 0,
        unsupported_model_behavior: Literal["ignore", "warn", "raise"] = "warn",
    ) -> None:
        """Store the cache control settings.

        Args:
            type: Cache type written into ``cache_control``; only
                "ephemeral" is accepted by the annotation.
            ttl: Cache time to live written into ``cache_control``; "5m"
                or "1h".
            min_messages_to_cache: Message count (the system prompt counts
                as one message) below which no ``cache_control`` is added.
                Defaults to 0.
            unsupported_model_behavior: What to do when the request's model
                is not a ``ChatAnthropic`` instance or ``langchain_anthropic``
                cannot be imported. "raise" raises ``ValueError``, "ignore"
                calls the handler without caching, "warn" emits a warning
                (see ``wrap_model_call`` for what happens afterwards).
        """
        self.unsupported_model_behavior = unsupported_model_behavior
        self.min_messages_to_cache = min_messages_to_cache
        self.ttl = ttl
        self.type = type

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelCallResult:
        """Add ``cache_control`` to the request's model settings and call ``handler``.

        Raises:
            ValueError: If the model is unsupported and
                ``unsupported_model_behavior`` is "raise".
        """
        # langchain_anthropic is optional here, so the import is attempted lazily.
        try:
            from langchain_anthropic import ChatAnthropic
        except ImportError:
            ChatAnthropic = None  # noqa: N806

        problem: str | None
        if ChatAnthropic is None:
            problem = (
                "AnthropicPromptCachingMiddleware caching middleware only supports "
                "Anthropic models. "
                "Please install langchain-anthropic."
            )
        elif isinstance(request.model, ChatAnthropic):
            problem = None
        else:
            problem = (
                "AnthropicPromptCachingMiddleware caching middleware only supports "
                f"Anthropic models, not instances of {type(request.model)}"
            )

        if problem is not None:
            behavior = self.unsupported_model_behavior
            if behavior == "raise":
                raise ValueError(problem)
            if behavior != "warn":
                # "ignore": hand the request over untouched.
                return handler(request)
            warn(problem, stacklevel=3)
            # NOTE(review): the "warn" path does not return here, so execution
            # continues to the message-count check below and cache_control may
            # still be attached -- confirm whether that is intended.

        system_prompt_count = 1 if request.system_prompt else 0
        total_messages = len(request.messages) + system_prompt_count
        if total_messages >= self.min_messages_to_cache:
            request.model_settings["cache_control"] = {
                "type": self.type,
                "ttl": self.ttl,
            }
        return handler(request)

View File

@@ -44,7 +44,6 @@ from langchain.agents.middleware.model_call_limit import (
ModelCallLimitExceededError,
)
from langchain.agents.middleware.model_fallback import ModelFallbackMiddleware
from langchain.agents.middleware.prompt_caching import AnthropicPromptCachingMiddleware
from langchain.agents.middleware.summarization import SummarizationMiddleware
from langchain.agents.middleware.types import (
AgentMiddleware,
@@ -1024,115 +1023,6 @@ def test_human_in_the_loop_middleware_description_as_callable() -> None:
assert captured_request["action_requests"][1]["description"] == "Static description"
# Tests for AnthropicPromptCachingMiddleware
def test_anthropic_prompt_caching_middleware_initialization() -> None:
    """Check stored constructor settings and the injected ``cache_control``."""
    # Explicit arguments are kept as given.
    configured = AnthropicPromptCachingMiddleware(
        type="ephemeral", ttl="1h", min_messages_to_cache=5
    )
    assert configured.type == "ephemeral"
    assert configured.ttl == "1h"
    assert configured.min_messages_to_cache == 5

    # Omitted arguments fall back to the defaults.
    middleware = AnthropicPromptCachingMiddleware()
    assert middleware.type == "ephemeral"
    assert middleware.ttl == "5m"
    assert middleware.min_messages_to_cache == 0

    request_fields = {
        "model": FakeToolCallingModel(),
        "messages": [HumanMessage("Hello")],
        "system_prompt": None,
        "tool_choice": None,
        "tools": [],
        "response_format": None,
        "state": {"messages": [HumanMessage("Hello")]},
        "runtime": cast(Runtime, object()),
        "model_settings": {},
    }
    fake_request = ModelRequest(**request_fields)

    def mock_handler(req: ModelRequest) -> AIMessage:
        # Forward whatever settings the middleware attached to the request.
        return AIMessage(content="mock response", **req.model_settings)

    middleware.wrap_model_call(fake_request, mock_handler)

    # The middleware writes cache_control onto the request it was given.
    expected_settings = {"cache_control": {"type": "ephemeral", "ttl": "5m"}}
    assert fake_request.model_settings == expected_settings
def test_anthropic_prompt_caching_middleware_unsupported_model() -> None:
"""Test the raise, warn and ignore behaviors of AnthropicPromptCachingMiddleware for an unsupported model."""
from typing import cast
fake_request = ModelRequest(
model=FakeToolCallingModel(),
messages=[HumanMessage("Hello")],
system_prompt=None,
tool_choice=None,
tools=[],
response_format=None,
state={"messages": [HumanMessage("Hello")]},
runtime=cast(Runtime, object()),
model_settings={},
)
# --- "raise": the middleware must raise ValueError ---
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="raise")
def mock_handler(req: ModelRequest) -> AIMessage:
return AIMessage(content="mock response")
# No sys.modules patching yet; the expected text is the "Please install" message.
# NOTE(review): presumably langchain_anthropic is not importable in this test environment -- confirm.
with pytest.raises(
ValueError,
match="AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models. Please install langchain-anthropic.",
):
middleware.wrap_model_call(fake_request, mock_handler)
# Stand-in module exposing a ChatAnthropic class that the fake model is not an instance of.
langchain_anthropic = ModuleType("langchain_anthropic")
class MockChatAnthropic:
pass
langchain_anthropic.ChatAnthropic = MockChatAnthropic
# With the stand-in registered, the expected text is the "not instances of" message.
with patch.dict("sys.modules", {"langchain_anthropic": langchain_anthropic}):
with pytest.raises(
ValueError,
match="AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models, not instances of",
):
middleware.wrap_model_call(fake_request, mock_handler)
# --- "warn": exactly one warning is recorded and the handler result is returned ---
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="warn")
with warnings.catch_warnings(record=True) as w:
result = middleware.wrap_model_call(fake_request, mock_handler)
assert len(w) == 1
assert (
"AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models. Please install langchain-anthropic."
in str(w[-1].message)
)
assert isinstance(result, AIMessage)
# Same "warn" check, this time with the stand-in module registered.
with warnings.catch_warnings(record=True) as w:
with patch.dict("sys.modules", {"langchain_anthropic": langchain_anthropic}):
result = middleware.wrap_model_call(fake_request, mock_handler)
assert isinstance(result, AIMessage)
assert len(w) == 1
assert (
"AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models, not instances of"
in str(w[-1].message)
)
# --- "ignore": the handler result is returned; no exception is expected ---
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore")
result = middleware.wrap_model_call(fake_request, mock_handler)
assert isinstance(result, AIMessage)
# Here sys.modules maps the name to a plain dict rather than a module object.
with patch.dict("sys.modules", {"langchain_anthropic": {"ChatAnthropic": object()}}):
result = middleware.wrap_model_call(fake_request, mock_handler)
assert isinstance(result, AIMessage)
# Tests for SummarizationMiddleware
def test_summarization_middleware_initialization() -> None:
"""Test SummarizationMiddleware initialization."""

View File

@@ -0,0 +1,9 @@
"""Middleware for Anthropic models."""
from langchain_anthropic.middleware.prompt_caching import (
AnthropicPromptCachingMiddleware,
)
# Names re-exported as the public API of this middleware package.
__all__ = [
"AnthropicPromptCachingMiddleware",
]

View File

@@ -0,0 +1,123 @@
"""Anthropic prompt caching middleware.
Requires:
- langchain: For agent middleware framework
- langchain-anthropic: For ChatAnthropic model (already a dependency)
"""
from collections.abc import Awaitable, Callable
from typing import Literal
from warnings import warn
from langchain_anthropic.chat_models import ChatAnthropic
try:
from langchain.agents.middleware.types import (
AgentMiddleware,
ModelCallResult,
ModelRequest,
ModelResponse,
)
except ImportError as e:
msg = (
"AnthropicPromptCachingMiddleware requires 'langchain' to be installed. "
"This middleware is designed for use with LangChain agents. "
"Install it with: pip install langchain"
)
raise ImportError(msg) from e
class AnthropicPromptCachingMiddleware(AgentMiddleware):
    """Prompt Caching Middleware.

    Optimizes API usage by caching conversation prefixes for Anthropic models.

    Requires both 'langchain' and 'langchain-anthropic' packages to be installed.

    Learn more about Anthropic prompt caching
    [here](https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching).
    """

    def __init__(
        self,
        type: Literal["ephemeral"] = "ephemeral",  # noqa: A002
        ttl: Literal["5m", "1h"] = "5m",
        min_messages_to_cache: int = 0,
        unsupported_model_behavior: Literal["ignore", "warn", "raise"] = "warn",
    ) -> None:
        """Initialize the middleware with cache control settings.

        Args:
            type: The type of cache to use, only "ephemeral" is supported.
            ttl: The time to live for the cache, only "5m" and "1h" are
                supported.
            min_messages_to_cache: The minimum number of messages until the
                cache is used, default is 0.
            unsupported_model_behavior: The behavior to take when an
                unsupported model is used. "ignore" will ignore the unsupported
                model and continue without caching. "warn" will warn the user
                and continue without caching. "raise" will raise an error and
                stop the agent.
        """
        self.type = type
        self.ttl = ttl
        self.min_messages_to_cache = min_messages_to_cache
        self.unsupported_model_behavior = unsupported_model_behavior

    def _should_apply_caching(self, request: ModelRequest) -> bool:
        """Decide whether ``cache_control`` should be added to ``request``.

        Shared by the sync and async hooks so both apply identical rules.

        Args:
            request: The model request about to be handed to the handler.

        Returns:
            ``True`` when the model is a ``ChatAnthropic`` instance and the
            message count (system prompt counted as one message) reaches
            ``min_messages_to_cache``; ``False`` otherwise.

        Raises:
            ValueError: If the model is not a ``ChatAnthropic`` instance and
                ``unsupported_model_behavior`` is "raise".
        """
        if not isinstance(request.model, ChatAnthropic):
            msg = (
                "AnthropicPromptCachingMiddleware caching middleware only supports "
                f"Anthropic models, not instances of {type(request.model)}"
            )
            if self.unsupported_model_behavior == "raise":
                raise ValueError(msg)
            if self.unsupported_model_behavior == "warn":
                # One level deeper than when warn() was called directly from
                # the hook, so the warning is attributed to the same frame.
                warn(msg, stacklevel=4)
            return False

        messages_count = len(request.messages) + (1 if request.system_prompt else 0)
        return messages_count >= self.min_messages_to_cache

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelCallResult:
        """Modify the model request to add cache control blocks.

        Args:
            request: The model request; its ``model_settings`` is updated in
                place when caching applies.
            handler: Callable that executes the request.

        Returns:
            Whatever ``handler`` returns.

        Raises:
            ValueError: For an unsupported model when
                ``unsupported_model_behavior`` is "raise".
        """
        if self._should_apply_caching(request):
            request.model_settings["cache_control"] = {
                "type": self.type,
                "ttl": self.ttl,
            }
        return handler(request)

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelCallResult:
        """Modify the model request to add cache control blocks (async version).

        Args:
            request: The model request; its ``model_settings`` is updated in
                place when caching applies.
            handler: Async callable that executes the request.

        Returns:
            Whatever awaiting ``handler`` returns.

        Raises:
            ValueError: For an unsupported model when
                ``unsupported_model_behavior`` is "raise".
        """
        if self._should_apply_caching(request):
            request.model_settings["cache_control"] = {
                "type": self.type,
                "ttl": self.ttl,
            }
        return await handler(request)

View File

@@ -41,6 +41,7 @@ test = [
"vcrpy>=7.0.0,<8.0.0",
"langchain-core",
"langchain-tests",
"langchain",
]
lint = ["ruff>=0.13.1,<0.14.0"]
dev = ["langchain-core"]
@@ -55,6 +56,7 @@ typing = [
[tool.uv.sources]
langchain-core = { path = "../../core", editable = true }
langchain-tests = { path = "../../standard-tests", editable = true }
langchain = { path = "../../langchain_v1", editable = true }
[tool.mypy]
disallow_untyped_defs = "True"

View File

@@ -0,0 +1 @@
"""Tests for Anthropic middleware."""

View File

@@ -0,0 +1,246 @@
"""Tests for Anthropic prompt caching middleware."""
import warnings
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langgraph.runtime import Runtime
from langchain_anthropic.chat_models import ChatAnthropic
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
class FakeToolCallingModel(BaseChatModel):
    """Chat model stand-in used by the middleware tests."""

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Return one generation whose content is the inputs joined by ``-``."""
        joined = "-".join(str(item.content) for item in messages)
        reply = AIMessage(content=joined, id="0")
        generation = ChatGeneration(message=reply)
        return ChatResult(generations=[generation])

    async def _agenerate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: AsyncCallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Async counterpart producing the same ``-``-joined reply."""
        joined = "-".join(str(item.content) for item in messages)
        reply = AIMessage(content=joined, id="0")
        generation = ChatGeneration(message=reply)
        return ChatResult(generations=[generation])

    @property
    def _llm_type(self) -> str:
        """Identifier string for this fake model type."""
        return "fake-tool-call-model"
def test_anthropic_prompt_caching_middleware_initialization() -> None:
    """Check stored constructor settings and the injected ``cache_control``."""
    # Explicit arguments are kept as given.
    configured = AnthropicPromptCachingMiddleware(
        type="ephemeral", ttl="1h", min_messages_to_cache=5
    )
    assert configured.type == "ephemeral"
    assert configured.ttl == "1h"
    assert configured.min_messages_to_cache == 5

    # Omitted arguments fall back to the defaults.
    middleware = AnthropicPromptCachingMiddleware()
    assert middleware.type == "ephemeral"
    assert middleware.ttl == "5m"
    assert middleware.min_messages_to_cache == 0

    # A mock built from the ChatAnthropic spec stands in for a real model.
    request_fields = {
        "model": MagicMock(spec=ChatAnthropic),
        "messages": [HumanMessage("Hello")],
        "system_prompt": None,
        "tool_choice": None,
        "tools": [],
        "response_format": None,
        "state": {"messages": [HumanMessage("Hello")]},
        "runtime": cast(Runtime, object()),
        "model_settings": {},
    }
    fake_request = ModelRequest(**request_fields)

    def mock_handler(req: ModelRequest) -> ModelResponse:
        reply = AIMessage(content="mock response")
        return ModelResponse(result=[reply])

    middleware.wrap_model_call(fake_request, mock_handler)

    # The middleware writes cache_control onto the request it was given.
    expected_settings = {"cache_control": {"type": "ephemeral", "ttl": "5m"}}
    assert fake_request.model_settings == expected_settings
def test_anthropic_prompt_caching_middleware_unsupported_model() -> None:
    """Test AnthropicPromptCachingMiddleware with unsupported model.

    Covers the three ``unsupported_model_behavior`` settings: "raise" must
    raise ``ValueError``; "warn" must emit exactly one warning; "warn" and
    "ignore" must both return the handler's response without adding
    ``cache_control`` to the request.
    """
    fake_request = ModelRequest(
        model=FakeToolCallingModel(),
        messages=[HumanMessage("Hello")],
        system_prompt=None,
        tool_choice=None,
        tools=[],
        response_format=None,
        state={"messages": [HumanMessage("Hello")]},
        runtime=cast(Runtime, object()),
        model_settings={},
    )

    def mock_handler(req: ModelRequest) -> ModelResponse:
        return ModelResponse(result=[AIMessage(content="mock response")])

    middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="raise")

    # Since we're in the langchain-anthropic package, ChatAnthropic is always
    # available. Test that it raises an error for unsupported model instances
    with pytest.raises(
        ValueError,
        match=(
            "AnthropicPromptCachingMiddleware caching middleware only supports "
            "Anthropic models, not instances of"
        ),
    ):
        middleware.wrap_model_call(fake_request, mock_handler)

    middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="warn")

    # Test warn behavior for unsupported model instances
    with warnings.catch_warnings(record=True) as w:
        # Record every warning regardless of the filters active in the
        # surrounding test session, so the count below is reliable.
        warnings.simplefilter("always")
        result = middleware.wrap_model_call(fake_request, mock_handler)
    assert isinstance(result, ModelResponse)
    assert len(w) == 1
    assert (
        "AnthropicPromptCachingMiddleware caching middleware only supports "
        "Anthropic models, not instances of"
    ) in str(w[-1].message)
    # Warning must not be followed by caching.
    assert fake_request.model_settings == {}

    # Test ignore behavior
    middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore")
    result = middleware.wrap_model_call(fake_request, mock_handler)
    assert isinstance(result, ModelResponse)
    # Ignoring must leave the request's settings untouched as well.
    assert fake_request.model_settings == {}
async def test_anthropic_prompt_caching_middleware_async() -> None:
    """Exercise ``awrap_model_call`` with a supported model and enough messages."""
    # Custom settings: caching requires at least five messages.
    middleware = AnthropicPromptCachingMiddleware(
        type="ephemeral", ttl="1h", min_messages_to_cache=5
    )

    # A mock built from the ChatAnthropic spec stands in for a real model;
    # six messages put the request above the configured minimum.
    request_fields = {
        "model": MagicMock(spec=ChatAnthropic),
        "messages": [HumanMessage("Hello")] * 6,
        "system_prompt": None,
        "tool_choice": None,
        "tools": [],
        "response_format": None,
        "state": {"messages": [HumanMessage("Hello")] * 6},
        "runtime": cast(Runtime, object()),
        "model_settings": {},
    }
    fake_request = ModelRequest(**request_fields)

    async def mock_handler(req: ModelRequest) -> ModelResponse:
        reply = AIMessage(content="mock response")
        return ModelResponse(result=[reply])

    result = await middleware.awrap_model_call(fake_request, mock_handler)
    assert isinstance(result, ModelResponse)

    # The middleware writes cache_control onto the request it was given.
    expected_settings = {"cache_control": {"type": "ephemeral", "ttl": "1h"}}
    assert fake_request.model_settings == expected_settings
async def test_anthropic_prompt_caching_middleware_async_unsupported_model() -> None:
    """Test AnthropicPromptCachingMiddleware async path with unsupported model.

    Covers the three ``unsupported_model_behavior`` settings: "raise" must
    raise ``ValueError``; "warn" must emit exactly one warning; "warn" and
    "ignore" must both return the handler's response without adding
    ``cache_control`` to the request.
    """
    fake_request = ModelRequest(
        model=FakeToolCallingModel(),
        messages=[HumanMessage("Hello")],
        system_prompt=None,
        tool_choice=None,
        tools=[],
        response_format=None,
        state={"messages": [HumanMessage("Hello")]},
        runtime=cast(Runtime, object()),
        model_settings={},
    )

    async def mock_handler(req: ModelRequest) -> ModelResponse:
        return ModelResponse(result=[AIMessage(content="mock response")])

    middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="raise")

    # Test that it raises an error for unsupported model instances
    with pytest.raises(
        ValueError,
        match=(
            "AnthropicPromptCachingMiddleware caching middleware only supports "
            "Anthropic models, not instances of"
        ),
    ):
        await middleware.awrap_model_call(fake_request, mock_handler)

    middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="warn")

    # Test warn behavior for unsupported model instances
    with warnings.catch_warnings(record=True) as w:
        # Record every warning regardless of the filters active in the
        # surrounding test session, so the count below is reliable.
        warnings.simplefilter("always")
        result = await middleware.awrap_model_call(fake_request, mock_handler)
    assert isinstance(result, ModelResponse)
    assert len(w) == 1
    assert (
        "AnthropicPromptCachingMiddleware caching middleware only supports "
        "Anthropic models, not instances of"
    ) in str(w[-1].message)
    # Warning must not be followed by caching.
    assert fake_request.model_settings == {}

    # Test ignore behavior
    middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore")
    result = await middleware.awrap_model_call(fake_request, mock_handler)
    assert isinstance(result, ModelResponse)
    # Ignoring must leave the request's settings untouched as well.
    assert fake_request.model_settings == {}
async def test_anthropic_prompt_caching_middleware_async_min_messages() -> None:
    """Test async path respects min_messages_to_cache."""
    middleware = AnthropicPromptCachingMiddleware(min_messages_to_cache=5)

    # Use a mock built from the ChatAnthropic spec so the request passes the
    # model-type check. With an unsupported model the middleware returns
    # before the message count is ever compared, and the assertion below
    # would hold for the wrong reason.
    mock_chat_anthropic = MagicMock(spec=ChatAnthropic)

    # Test with fewer messages than minimum
    fake_request = ModelRequest(
        model=mock_chat_anthropic,
        messages=[HumanMessage("Hello")] * 3,
        system_prompt=None,
        tool_choice=None,
        tools=[],
        response_format=None,
        state={"messages": [HumanMessage("Hello")] * 3},
        runtime=cast(Runtime, object()),
        model_settings={},
    )

    async def mock_handler(req: ModelRequest) -> ModelResponse:
        return ModelResponse(result=[AIMessage(content="mock response")])

    result = await middleware.awrap_model_call(fake_request, mock_handler)
    assert isinstance(result, ModelResponse)

    # Cache control should NOT be added when message count is below minimum
    assert fake_request.model_settings == {}

File diff suppressed because it is too large Load Diff