mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
chore(langchain_v1,anthropic): migrate anthropic middleware to langchain_anthropic (#33463)
Migrate prompt caching implementation into langchain_anthropic.middleware
This commit is contained in:
@@ -12,7 +12,6 @@ from .model_call_limit import ModelCallLimitMiddleware
|
||||
from .model_fallback import ModelFallbackMiddleware
|
||||
from .pii import PIIDetectionError, PIIMiddleware
|
||||
from .planning import PlanningMiddleware
|
||||
from .prompt_caching import AnthropicPromptCachingMiddleware
|
||||
from .summarization import SummarizationMiddleware
|
||||
from .tool_call_limit import ToolCallLimitMiddleware
|
||||
from .tool_emulator import LLMToolEmulator
|
||||
@@ -33,8 +32,6 @@ from .types import (
|
||||
__all__ = [
|
||||
"AgentMiddleware",
|
||||
"AgentState",
|
||||
# should move to langchain-anthropic if we decide to keep it
|
||||
"AnthropicPromptCachingMiddleware",
|
||||
"ClearToolUsesEdit",
|
||||
"ContextEditingMiddleware",
|
||||
"HumanInTheLoopMiddleware",
|
||||
|
||||
@@ -1,89 +0,0 @@
|
||||
"""Anthropic prompt caching middleware."""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Literal
|
||||
from warnings import warn
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
ModelCallResult,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
)
|
||||
|
||||
|
||||
class AnthropicPromptCachingMiddleware(AgentMiddleware):
    """Middleware that requests Anthropic prompt caching on model calls.

    For qualifying requests a ``cache_control`` entry is written into
    ``request.model_settings`` before the wrapped handler is invoked.

    Background on the feature:
    https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
    """

    def __init__(
        self,
        type: Literal["ephemeral"] = "ephemeral",
        ttl: Literal["5m", "1h"] = "5m",
        min_messages_to_cache: int = 0,
        unsupported_model_behavior: Literal["ignore", "warn", "raise"] = "warn",
    ) -> None:
        """Record the cache-control configuration.

        Args:
            type: Value stored under ``"type"`` in the cache-control setting;
                only "ephemeral" is accepted by the annotation.
            ttl: Value stored under ``"ttl"``; "5m" or "1h".
            min_messages_to_cache: Requests with fewer messages than this
                (the system prompt counts as one) are passed through
                without cache control. Defaults to 0.
            unsupported_model_behavior: Reaction when ``langchain-anthropic``
                is missing or the model is not a ``ChatAnthropic``:
                "raise" raises ``ValueError``, "ignore" calls the handler
                with the request untouched, "warn" emits a warning.
        """
        self.unsupported_model_behavior = unsupported_model_behavior
        self.min_messages_to_cache = min_messages_to_cache
        self.ttl = ttl
        self.type = type

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelCallResult:
        """Attach cache control to ``request`` when applicable, then delegate.

        Raises:
            ValueError: If the model is unsupported and
                ``unsupported_model_behavior`` is "raise".
        """
        # Imported lazily so the middleware module loads without the
        # optional partner package.
        try:
            from langchain_anthropic import ChatAnthropic
        except ImportError:
            ChatAnthropic = None  # noqa: N806

        problem: str | None
        if ChatAnthropic is None:
            problem = (
                "AnthropicPromptCachingMiddleware caching middleware only supports "
                "Anthropic models. "
                "Please install langchain-anthropic."
            )
        elif isinstance(request.model, ChatAnthropic):
            problem = None
        else:
            problem = (
                "AnthropicPromptCachingMiddleware caching middleware only supports "
                f"Anthropic models, not instances of {type(request.model)}"
            )

        if problem is not None:
            behavior = self.unsupported_model_behavior
            if behavior == "raise":
                raise ValueError(problem)
            if behavior != "warn":
                return handler(request)
            # NOTE(review): after warning, execution continues below, so
            # cache control may still be added for an unsupported model.
            # Confirm whether "warn" was meant to skip caching.
            warn(problem, stacklevel=3)

        # A non-empty system prompt counts as one extra message.
        effective_count = len(request.messages) + (1 if request.system_prompt else 0)
        if effective_count >= self.min_messages_to_cache:
            request.model_settings["cache_control"] = {
                "type": self.type,
                "ttl": self.ttl,
            }
        return handler(request)
|
||||
@@ -44,7 +44,6 @@ from langchain.agents.middleware.model_call_limit import (
|
||||
ModelCallLimitExceededError,
|
||||
)
|
||||
from langchain.agents.middleware.model_fallback import ModelFallbackMiddleware
|
||||
from langchain.agents.middleware.prompt_caching import AnthropicPromptCachingMiddleware
|
||||
from langchain.agents.middleware.summarization import SummarizationMiddleware
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
@@ -1024,115 +1023,6 @@ def test_human_in_the_loop_middleware_description_as_callable() -> None:
|
||||
assert captured_request["action_requests"][1]["description"] == "Static description"
|
||||
|
||||
|
||||
# Tests for AnthropicPromptCachingMiddleware
|
||||
def test_anthropic_prompt_caching_middleware_initialization() -> None:
    """Check constructor attributes and the cache-control setting it writes."""
    # Explicit configuration is stored verbatim.
    middleware = AnthropicPromptCachingMiddleware(
        type="ephemeral", ttl="1h", min_messages_to_cache=5
    )
    assert middleware.type == "ephemeral"
    assert middleware.ttl == "1h"
    assert middleware.min_messages_to_cache == 5

    # Defaults.
    middleware = AnthropicPromptCachingMiddleware()
    assert middleware.type == "ephemeral"
    assert middleware.ttl == "5m"
    assert middleware.min_messages_to_cache == 0

    request_kwargs = {
        "model": FakeToolCallingModel(),
        "messages": [HumanMessage("Hello")],
        "system_prompt": None,
        "tool_choice": None,
        "tools": [],
        "response_format": None,
        "state": {"messages": [HumanMessage("Hello")]},
        "runtime": cast(Runtime, object()),
        "model_settings": {},
    }
    fake_request = ModelRequest(**request_kwargs)

    def mock_handler(req: ModelRequest) -> AIMessage:
        # Forward whatever settings the middleware left on the request.
        return AIMessage(content="mock response", **req.model_settings)

    middleware.wrap_model_call(fake_request, mock_handler)

    # The middleware mutates the request's settings in place.
    expected_settings = {"cache_control": {"type": "ephemeral", "ttl": "5m"}}
    assert fake_request.model_settings == expected_settings
|
||||
|
||||
|
||||
def test_anthropic_prompt_caching_middleware_unsupported_model() -> None:
    """Exercise raise / warn / ignore handling for a non-Anthropic model."""
    from typing import cast

    not_installed_msg = (
        "AnthropicPromptCachingMiddleware caching middleware only supports "
        "Anthropic models. Please install langchain-anthropic."
    )
    wrong_type_msg = (
        "AnthropicPromptCachingMiddleware caching middleware only supports "
        "Anthropic models, not instances of"
    )

    fake_request = ModelRequest(
        model=FakeToolCallingModel(),
        messages=[HumanMessage("Hello")],
        system_prompt=None,
        tool_choice=None,
        tools=[],
        response_format=None,
        state={"messages": [HumanMessage("Hello")]},
        runtime=cast(Runtime, object()),
        model_settings={},
    )

    def mock_handler(req: ModelRequest) -> AIMessage:
        return AIMessage(content="mock response")

    # Stand-in module exposing a ChatAnthropic the fake model is not an
    # instance of.
    class MockChatAnthropic:
        pass

    stub_module = ModuleType("langchain_anthropic")
    stub_module.ChatAnthropic = MockChatAnthropic
    stubbed_modules = {"langchain_anthropic": stub_module}

    # --- "raise" ---
    middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="raise")

    with pytest.raises(ValueError, match=not_installed_msg):
        middleware.wrap_model_call(fake_request, mock_handler)

    with patch.dict("sys.modules", stubbed_modules):
        with pytest.raises(ValueError, match=wrong_type_msg):
            middleware.wrap_model_call(fake_request, mock_handler)

    # --- "warn" ---
    middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="warn")

    with warnings.catch_warnings(record=True) as w:
        result = middleware.wrap_model_call(fake_request, mock_handler)
        assert len(w) == 1
        assert not_installed_msg in str(w[-1].message)
        assert isinstance(result, AIMessage)

    with warnings.catch_warnings(record=True) as w:
        with patch.dict("sys.modules", stubbed_modules):
            result = middleware.wrap_model_call(fake_request, mock_handler)
            assert isinstance(result, AIMessage)
            assert len(w) == 1
            assert wrong_type_msg in str(w[-1].message)

    # --- "ignore" ---
    middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore")

    result = middleware.wrap_model_call(fake_request, mock_handler)
    assert isinstance(result, AIMessage)

    with patch.dict("sys.modules", {"langchain_anthropic": {"ChatAnthropic": object()}}):
        result = middleware.wrap_model_call(fake_request, mock_handler)
        assert isinstance(result, AIMessage)
|
||||
|
||||
|
||||
# Tests for SummarizationMiddleware
|
||||
def test_summarization_middleware_initialization() -> None:
|
||||
"""Test SummarizationMiddleware initialization."""
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
"""Middleware for Anthropic models."""
|
||||
|
||||
from langchain_anthropic.middleware.prompt_caching import (
|
||||
AnthropicPromptCachingMiddleware,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AnthropicPromptCachingMiddleware",
|
||||
]
|
||||
@@ -0,0 +1,123 @@
|
||||
"""Anthropic prompt caching middleware.
|
||||
|
||||
Requires:
|
||||
- langchain: For agent middleware framework
|
||||
- langchain-anthropic: For ChatAnthropic model (already a dependency)
|
||||
"""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Literal
|
||||
from warnings import warn
|
||||
|
||||
from langchain_anthropic.chat_models import ChatAnthropic
|
||||
|
||||
try:
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
ModelCallResult,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
)
|
||||
except ImportError as e:
|
||||
msg = (
|
||||
"AnthropicPromptCachingMiddleware requires 'langchain' to be installed. "
|
||||
"This middleware is designed for use with LangChain agents. "
|
||||
"Install it with: pip install langchain"
|
||||
)
|
||||
raise ImportError(msg) from e
|
||||
|
||||
|
||||
class AnthropicPromptCachingMiddleware(AgentMiddleware):
    """Prompt Caching Middleware.

    Optimizes API usage by caching conversation prefixes for Anthropic models.

    Requires both 'langchain' and 'langchain-anthropic' packages to be installed.

    Learn more about Anthropic prompt caching
    [here](https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching).
    """

    def __init__(
        self,
        type: Literal["ephemeral"] = "ephemeral",  # noqa: A002
        ttl: Literal["5m", "1h"] = "5m",
        min_messages_to_cache: int = 0,
        unsupported_model_behavior: Literal["ignore", "warn", "raise"] = "warn",
    ) -> None:
        """Initialize the middleware with cache control settings.

        Args:
            type: The type of cache to use, only "ephemeral" is supported.
            ttl: The time to live for the cache, only "5m" and "1h" are
                supported.
            min_messages_to_cache: The minimum number of messages until the
                cache is used, default is 0.
            unsupported_model_behavior: The behavior to take when an
                unsupported model is used. "ignore" will ignore the unsupported
                model and continue without caching. "warn" will warn the user
                and continue without caching. "raise" will raise an error and
                stop the agent.
        """
        self.type = type
        self.ttl = ttl
        self.min_messages_to_cache = min_messages_to_cache
        self.unsupported_model_behavior = unsupported_model_behavior

    def _maybe_add_cache_control(self, request: ModelRequest) -> None:
        """Write the cache-control setting into ``request`` if it qualifies.

        Shared by the sync and async entry points so both apply the same
        policy.

        Args:
            request: The model request; its ``model_settings`` mapping is
                mutated in place when caching applies.

        Raises:
            ValueError: If the model is not a ``ChatAnthropic`` and
                ``unsupported_model_behavior`` is "raise".
        """
        if not isinstance(request.model, ChatAnthropic):
            msg = (
                "AnthropicPromptCachingMiddleware caching middleware only supports "
                f"Anthropic models, not instances of {type(request.model)}"
            )
            if self.unsupported_model_behavior == "raise":
                raise ValueError(msg)
            if self.unsupported_model_behavior == "warn":
                # One level deeper than a direct call from the wrap methods,
                # so the warning is attributed to the same frame as before
                # this helper existed.
                warn(msg, stacklevel=4)
            # "warn" and "ignore" both proceed without caching.
            return

        # A non-empty system prompt counts as one extra message.
        messages_count = (
            len(request.messages) + 1
            if request.system_prompt
            else len(request.messages)
        )
        if messages_count < self.min_messages_to_cache:
            return

        request.model_settings["cache_control"] = {"type": self.type, "ttl": self.ttl}

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelCallResult:
        """Modify the model request to add cache control blocks.

        Args:
            request: The model request to inspect and possibly modify.
            handler: Callable that executes the model call.

        Returns:
            Whatever ``handler`` returns for the (possibly modified) request.

        Raises:
            ValueError: For an unsupported model when configured to "raise".
        """
        self._maybe_add_cache_control(request)
        return handler(request)

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelCallResult:
        """Modify the model request to add cache control blocks (async version).

        Args:
            request: The model request to inspect and possibly modify.
            handler: Async callable that executes the model call.

        Returns:
            The awaited result of ``handler`` for the (possibly modified)
            request.

        Raises:
            ValueError: For an unsupported model when configured to "raise".
        """
        self._maybe_add_cache_control(request)
        return await handler(request)
|
||||
@@ -41,6 +41,7 @@ test = [
|
||||
"vcrpy>=7.0.0,<8.0.0",
|
||||
"langchain-core",
|
||||
"langchain-tests",
|
||||
"langchain",
|
||||
]
|
||||
lint = ["ruff>=0.13.1,<0.14.0"]
|
||||
dev = ["langchain-core"]
|
||||
@@ -55,6 +56,7 @@ typing = [
|
||||
[tool.uv.sources]
|
||||
langchain-core = { path = "../../core", editable = true }
|
||||
langchain-tests = { path = "../../standard-tests", editable = true }
|
||||
langchain = { path = "../../langchain_v1", editable = true }
|
||||
|
||||
[tool.mypy]
|
||||
disallow_untyped_defs = "True"
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
"""Tests for Anthropic middleware."""
|
||||
@@ -0,0 +1,246 @@
|
||||
"""Tests for Anthropic prompt caching middleware."""
|
||||
|
||||
import warnings
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from langchain.agents.middleware.types import ModelRequest, ModelResponse
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from langchain_anthropic.chat_models import ChatAnthropic
|
||||
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
|
||||
|
||||
|
||||
class FakeToolCallingModel(BaseChatModel):
    """Chat model stand-in for middleware tests.

    Its reply is the string form of every input message's content, joined
    with ``-``.
    """

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Build the single-generation result synchronously."""
        joined = "-".join(str(item.content) for item in messages)
        reply = AIMessage(content=joined, id="0")
        return ChatResult(generations=[ChatGeneration(message=reply)])

    async def _agenerate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: AsyncCallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Build the single-generation result; async twin of ``_generate``."""
        joined = "-".join(str(item.content) for item in messages)
        reply = AIMessage(content=joined, id="0")
        return ChatResult(generations=[ChatGeneration(message=reply)])

    @property
    def _llm_type(self) -> str:
        """Identifier string for this fake model."""
        return "fake-tool-call-model"
|
||||
|
||||
|
||||
def test_anthropic_prompt_caching_middleware_initialization() -> None:
    """Check constructor attributes and the cache-control setting it writes."""
    # Explicit configuration is stored verbatim.
    middleware = AnthropicPromptCachingMiddleware(
        type="ephemeral", ttl="1h", min_messages_to_cache=5
    )
    assert middleware.type == "ephemeral"
    assert middleware.ttl == "1h"
    assert middleware.min_messages_to_cache == 5

    # Defaults.
    middleware = AnthropicPromptCachingMiddleware()
    assert middleware.type == "ephemeral"
    assert middleware.ttl == "5m"
    assert middleware.min_messages_to_cache == 0

    # A spec'd mock passes the middleware's isinstance(ChatAnthropic) check.
    request_kwargs = {
        "model": MagicMock(spec=ChatAnthropic),
        "messages": [HumanMessage("Hello")],
        "system_prompt": None,
        "tool_choice": None,
        "tools": [],
        "response_format": None,
        "state": {"messages": [HumanMessage("Hello")]},
        "runtime": cast(Runtime, object()),
        "model_settings": {},
    }
    fake_request = ModelRequest(**request_kwargs)

    def mock_handler(req: ModelRequest) -> ModelResponse:
        return ModelResponse(result=[AIMessage(content="mock response")])

    middleware.wrap_model_call(fake_request, mock_handler)

    # The middleware mutates the request's settings in place.
    expected_settings = {"cache_control": {"type": "ephemeral", "ttl": "5m"}}
    assert fake_request.model_settings == expected_settings
|
||||
|
||||
|
||||
def test_anthropic_prompt_caching_middleware_unsupported_model() -> None:
    """Exercise raise / warn / ignore handling for a non-Anthropic model."""
    wrong_type_msg = (
        "AnthropicPromptCachingMiddleware caching middleware only supports "
        "Anthropic models, not instances of"
    )

    fake_request = ModelRequest(
        model=FakeToolCallingModel(),
        messages=[HumanMessage("Hello")],
        system_prompt=None,
        tool_choice=None,
        tools=[],
        response_format=None,
        state={"messages": [HumanMessage("Hello")]},
        runtime=cast(Runtime, object()),
        model_settings={},
    )

    def mock_handler(req: ModelRequest) -> ModelResponse:
        return ModelResponse(result=[AIMessage(content="mock response")])

    # "raise": ChatAnthropic is always importable inside this package, so
    # only the wrong-model-type error can occur.
    middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="raise")
    with pytest.raises(ValueError, match=wrong_type_msg):
        middleware.wrap_model_call(fake_request, mock_handler)

    # "warn": one warning is recorded and the handler result is returned.
    middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="warn")
    with warnings.catch_warnings(record=True) as w:
        result = middleware.wrap_model_call(fake_request, mock_handler)
        assert isinstance(result, ModelResponse)
        assert len(w) == 1
        assert wrong_type_msg in str(w[-1].message)

    # "ignore": the handler result is returned.
    middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore")
    result = middleware.wrap_model_call(fake_request, mock_handler)
    assert isinstance(result, ModelResponse)
|
||||
|
||||
|
||||
async def test_anthropic_prompt_caching_middleware_async() -> None:
    """Async path writes cache control once the message minimum is met."""
    middleware = AnthropicPromptCachingMiddleware(
        type="ephemeral", ttl="1h", min_messages_to_cache=5
    )

    # A spec'd mock passes the middleware's isinstance(ChatAnthropic) check.
    request_kwargs = {
        "model": MagicMock(spec=ChatAnthropic),
        # Six messages: above the configured minimum of five.
        "messages": [HumanMessage("Hello")] * 6,
        "system_prompt": None,
        "tool_choice": None,
        "tools": [],
        "response_format": None,
        "state": {"messages": [HumanMessage("Hello")] * 6},
        "runtime": cast(Runtime, object()),
        "model_settings": {},
    }
    fake_request = ModelRequest(**request_kwargs)

    async def mock_handler(req: ModelRequest) -> ModelResponse:
        return ModelResponse(result=[AIMessage(content="mock response")])

    result = await middleware.awrap_model_call(fake_request, mock_handler)
    assert isinstance(result, ModelResponse)

    # The middleware mutates the request's settings in place.
    expected_settings = {"cache_control": {"type": "ephemeral", "ttl": "1h"}}
    assert fake_request.model_settings == expected_settings
|
||||
|
||||
|
||||
async def test_anthropic_prompt_caching_middleware_async_unsupported_model() -> None:
    """Exercise async raise / warn / ignore handling for a non-Anthropic model."""
    wrong_type_msg = (
        "AnthropicPromptCachingMiddleware caching middleware only supports "
        "Anthropic models, not instances of"
    )

    fake_request = ModelRequest(
        model=FakeToolCallingModel(),
        messages=[HumanMessage("Hello")],
        system_prompt=None,
        tool_choice=None,
        tools=[],
        response_format=None,
        state={"messages": [HumanMessage("Hello")]},
        runtime=cast(Runtime, object()),
        model_settings={},
    )

    async def mock_handler(req: ModelRequest) -> ModelResponse:
        return ModelResponse(result=[AIMessage(content="mock response")])

    # "raise": the wrong-model-type error propagates to the caller.
    middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="raise")
    with pytest.raises(ValueError, match=wrong_type_msg):
        await middleware.awrap_model_call(fake_request, mock_handler)

    # "warn": one warning is recorded and the handler result is returned.
    middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="warn")
    with warnings.catch_warnings(record=True) as w:
        result = await middleware.awrap_model_call(fake_request, mock_handler)
        assert isinstance(result, ModelResponse)
        assert len(w) == 1
        assert wrong_type_msg in str(w[-1].message)

    # "ignore": the handler result is returned.
    middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore")
    result = await middleware.awrap_model_call(fake_request, mock_handler)
    assert isinstance(result, ModelResponse)
|
||||
|
||||
|
||||
async def test_anthropic_prompt_caching_middleware_async_min_messages() -> None:
    """Test async path respects min_messages_to_cache."""
    middleware = AnthropicPromptCachingMiddleware(min_messages_to_cache=5)

    # Use a ChatAnthropic-spec'd mock so the request passes the model check
    # and actually reaches the message-count threshold. With an unsupported
    # model the middleware would return early on the "warn" path and the
    # assertion below would pass without exercising min_messages_to_cache.
    mock_chat_anthropic = MagicMock(spec=ChatAnthropic)

    # Test with fewer messages than minimum
    fake_request = ModelRequest(
        model=mock_chat_anthropic,
        messages=[HumanMessage("Hello")] * 3,
        system_prompt=None,
        tool_choice=None,
        tools=[],
        response_format=None,
        state={"messages": [HumanMessage("Hello")] * 3},
        runtime=cast(Runtime, object()),
        model_settings={},
    )

    async def mock_handler(req: ModelRequest) -> ModelResponse:
        return ModelResponse(result=[AIMessage(content="mock response")])

    result = await middleware.awrap_model_call(fake_request, mock_handler)
    assert isinstance(result, ModelResponse)
    # Cache control should NOT be added when message count is below minimum
    assert fake_request.model_settings == {}
|
||||
1382
libs/partners/anthropic/uv.lock
generated
1382
libs/partners/anthropic/uv.lock
generated
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user