Mirror of https://github.com/hwchase17/langchain.git (synced 2026-02-04 00:00:34 +00:00)

Compare commits: langchain- ... async_impl (8 commits)
| Author | SHA1 | Date |
|---|---|---|
| | fed37ddf96 | |
| | 9f6b660f00 | |
| | 00c2380019 | |
| | c53ffe4b82 | |
| | 23f5b0cedf | |
| | 89e3a10cbd | |
| | fad0e2dbd8 | |
| | 6c2f0eb67a | |
@@ -12,7 +12,6 @@ from .model_call_limit import ModelCallLimitMiddleware
from .model_fallback import ModelFallbackMiddleware
from .pii import PIIDetectionError, PIIMiddleware
from .planning import PlanningMiddleware
from .prompt_caching import AnthropicPromptCachingMiddleware
from .summarization import SummarizationMiddleware
from .tool_call_limit import ToolCallLimitMiddleware
from .tool_emulator import LLMToolEmulator
@@ -33,8 +32,6 @@ from .types import (
__all__ = [
    "AgentMiddleware",
    "AgentState",
    # should move to langchain-anthropic if we decide to keep it
    "AnthropicPromptCachingMiddleware",
    "ClearToolUsesEdit",
    "ContextEditingMiddleware",
    "HumanInTheLoopMiddleware",
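With the re-export dropped here, the prompt caching middleware is picked up from the Anthropic partner package instead (the new files appear later in this diff). A minimal sketch of the updated import path:

# Old location (removed in this PR):
# from langchain.agents.middleware import AnthropicPromptCachingMiddleware

# New location, re-exported by langchain_anthropic.middleware:
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware

middleware = AnthropicPromptCachingMiddleware(ttl="5m", min_messages_to_cache=3)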
@@ -8,7 +8,7 @@ with any LangChain chat model.

from __future__ import annotations

-from collections.abc import Callable, Iterable, Sequence
+from collections.abc import Awaitable, Callable, Iterable, Sequence
from dataclasses import dataclass
from typing import Literal
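The added Awaitable import supports the async handler type used by the new awrap_model_call hooks below. A small illustrative handler matching that type (the handler body is a stand-in, not the framework's real implementation):

from collections.abc import Awaitable, Callable

from langchain_core.messages import AIMessage

from langchain.agents.middleware.types import ModelRequest, ModelResponse


async def echo_handler(request: ModelRequest) -> ModelResponse:
    # Stand-in for the framework-provided handler that actually invokes the model.
    return ModelResponse(result=[AIMessage(content="ok")])


handler: Callable[[ModelRequest], Awaitable[ModelResponse]] = echo_handler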
@@ -239,6 +239,34 @@ class ContextEditingMiddleware(AgentMiddleware):

        return handler(request)

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelCallResult:
        """Apply context edits before invoking the model via handler (async version)."""
        if not request.messages:
            return await handler(request)

        if self.token_count_method == "approximate":  # noqa: S105

            def count_tokens(messages: Sequence[BaseMessage]) -> int:
                return count_tokens_approximately(messages)
        else:
            system_msg = (
                [SystemMessage(content=request.system_prompt)] if request.system_prompt else []
            )

            def count_tokens(messages: Sequence[BaseMessage]) -> int:
                return request.model.get_num_tokens_from_messages(
                    system_msg + list(messages), request.tools
                )

        for edit in self.edits:
            edit.apply(request.messages, count_tokens=count_tokens)

        return await handler(request)


__all__ = [
    "ClearToolUsesEdit",
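The sync and async paths share the same token-counting logic; which counter is used depends on token_count_method. A rough construction sketch of the two modes (trigger value chosen arbitrarily for illustration):

from langchain.agents.middleware import ClearToolUsesEdit, ContextEditingMiddleware

# Approximate counting: no model round-trip, uses count_tokens_approximately.
approx_middleware = ContextEditingMiddleware(
    edits=[ClearToolUsesEdit(trigger=2000)],
    token_count_method="approximate",
)

# Model-based counting: defers to request.model.get_num_tokens_from_messages,
# which also accounts for the system prompt and tool definitions.
model_middleware = ContextEditingMiddleware(
    edits=[ClearToolUsesEdit(trigger=2000)],
    token_count_method="model",
)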
@@ -13,7 +13,7 @@ from langchain.agents.middleware.types import (
from langchain.chat_models import init_chat_model

if TYPE_CHECKING:
-    from collections.abc import Callable
+    from collections.abc import Awaitable, Callable

    from langchain_core.language_models.chat_models import BaseChatModel
@@ -102,3 +102,38 @@ class ModelFallbackMiddleware(AgentMiddleware):
                continue

        raise last_exception

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelCallResult:
        """Try fallback models in sequence on errors (async version).

        Args:
            request: Initial model request.
            handler: Async callback to execute the model.

        Returns:
            AIMessage from successful model call.

        Raises:
            Exception: If all models fail, re-raises last exception.
        """
        # Try primary model first
        last_exception: Exception
        try:
            return await handler(request)
        except Exception as e:  # noqa: BLE001
            last_exception = e

        # Try fallback models
        for fallback_model in self.models:
            request.model = fallback_model
            try:
                return await handler(request)
            except Exception as e:  # noqa: BLE001
                last_exception = e
                continue

        raise last_exception
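For reference, a minimal sketch of wiring the fallback chain. The positional arguments are the fallbacks tried, in order, after the request's primary model raises; the model identifiers here are placeholders, not part of this diff:

from langchain.agents.middleware import ModelFallbackMiddleware
from langchain.chat_models import init_chat_model

fallback = ModelFallbackMiddleware(
    init_chat_model("openai:gpt-4o-mini"),  # hypothetical fallback model id
    init_chat_model("anthropic:claude-3-5-haiku-latest"),  # hypothetical fallback model id
)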
@@ -6,7 +6,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Annotated, Literal

if TYPE_CHECKING:
-    from collections.abc import Callable
+    from collections.abc import Awaitable, Callable

    from langchain_core.messages import ToolMessage
    from langchain_core.tools import tool
@@ -204,3 +204,16 @@ class PlanningMiddleware(AgentMiddleware):
            else self.system_prompt
        )
        return handler(request)

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelCallResult:
        """Update the system prompt to include the todo system prompt (async version)."""
        request.system_prompt = (
            request.system_prompt + "\n\n" + self.system_prompt
            if request.system_prompt
            else self.system_prompt
        )
        return await handler(request)
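The prompt update is a plain string concatenation, so the handler sees roughly the following (the todo prompt text below is a placeholder; the real constant lives elsewhere in the module):

existing = "You are a helpful assistant."
todo_prompt = "## write_todos ..."  # placeholder for PlanningMiddleware's system_prompt

# With an existing prompt, the todo instructions are appended after a blank line;
# with no existing prompt, the todo instructions are used as-is.
combined = existing + "\n\n" + todo_prompt if existing else todo_prompt
assert combined.startswith(existing)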
@@ -1,89 +0,0 @@
"""Anthropic prompt caching middleware."""

from collections.abc import Callable
from typing import Literal
from warnings import warn

from langchain.agents.middleware.types import (
    AgentMiddleware,
    ModelCallResult,
    ModelRequest,
    ModelResponse,
)


class AnthropicPromptCachingMiddleware(AgentMiddleware):
    """Prompt Caching Middleware.

    Optimizes API usage by caching conversation prefixes for Anthropic models.

    Learn more about Anthropic prompt caching
    [here](https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching).
    """

    def __init__(
        self,
        type: Literal["ephemeral"] = "ephemeral",
        ttl: Literal["5m", "1h"] = "5m",
        min_messages_to_cache: int = 0,
        unsupported_model_behavior: Literal["ignore", "warn", "raise"] = "warn",
    ) -> None:
        """Initialize the middleware with cache control settings.

        Args:
            type: The type of cache to use, only "ephemeral" is supported.
            ttl: The time to live for the cache, only "5m" and "1h" are supported.
            min_messages_to_cache: The minimum number of messages until the cache is used,
                default is 0.
            unsupported_model_behavior: The behavior to take when an unsupported model is used.
                "ignore" will ignore the unsupported model and continue without caching.
                "warn" will warn the user and continue without caching.
                "raise" will raise an error and stop the agent.
        """
        self.type = type
        self.ttl = ttl
        self.min_messages_to_cache = min_messages_to_cache
        self.unsupported_model_behavior = unsupported_model_behavior

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelCallResult:
        """Modify the model request to add cache control blocks."""
        try:
            from langchain_anthropic import ChatAnthropic
        except ImportError:
            ChatAnthropic = None  # noqa: N806

        msg: str | None = None

        if ChatAnthropic is None:
            msg = (
                "AnthropicPromptCachingMiddleware caching middleware only supports "
                "Anthropic models. "
                "Please install langchain-anthropic."
            )
        elif not isinstance(request.model, ChatAnthropic):
            msg = (
                "AnthropicPromptCachingMiddleware caching middleware only supports "
                f"Anthropic models, not instances of {type(request.model)}"
            )

        if msg is not None:
            if self.unsupported_model_behavior == "raise":
                raise ValueError(msg)
            if self.unsupported_model_behavior == "warn":
                warn(msg, stacklevel=3)
            else:
                return handler(request)

        messages_count = (
            len(request.messages) + 1 if request.system_prompt else len(request.messages)
        )
        if messages_count < self.min_messages_to_cache:
            return handler(request)

        request.model_settings["cache_control"] = {"type": self.type, "ttl": self.ttl}

        return handler(request)
@@ -233,3 +233,169 @@ def test_exclude_tools_prevents_clearing() -> None:

def _fake_runtime() -> Runtime:
    return cast(Runtime, object())


async def test_no_edit_when_below_trigger_async() -> None:
    """Test async version of context editing with no edit when below trigger."""
    tool_call_id = "call-1"
    ai_message = AIMessage(
        content="",
        tool_calls=[{"id": tool_call_id, "name": "search", "args": {}}],
    )
    tool_message = ToolMessage(content="12345", tool_call_id=tool_call_id)

    state, request = _make_state_and_request([ai_message, tool_message])
    middleware = ContextEditingMiddleware(
        edits=[ClearToolUsesEdit(trigger=50)],
    )

    async def mock_handler(req: ModelRequest) -> AIMessage:
        return AIMessage(content="mock response")

    # Call awrap_model_call which modifies the request
    await middleware.awrap_model_call(request, mock_handler)

    # The request should have been modified in place
    assert request.messages[0].content == ""
    assert request.messages[1].content == "12345"
    assert state["messages"] == request.messages


async def test_clear_tool_outputs_and_inputs_async() -> None:
    """Test async version of clearing tool outputs and inputs."""
    tool_call_id = "call-2"
    ai_message = AIMessage(
        content=[
            {"type": "tool_call", "id": tool_call_id, "name": "search", "args": {"query": "foo"}}
        ],
        tool_calls=[{"id": tool_call_id, "name": "search", "args": {"query": "foo"}}],
    )
    tool_message = ToolMessage(content="x" * 200, tool_call_id=tool_call_id)

    state, request = _make_state_and_request([ai_message, tool_message])

    edit = ClearToolUsesEdit(
        trigger=50,
        clear_at_least=10,
        clear_tool_inputs=True,
        keep=0,
        placeholder="[cleared output]",
    )
    middleware = ContextEditingMiddleware(edits=[edit])

    async def mock_handler(req: ModelRequest) -> AIMessage:
        return AIMessage(content="mock response")

    # Call awrap_model_call which modifies the request
    await middleware.awrap_model_call(request, mock_handler)

    cleared_ai = request.messages[0]
    cleared_tool = request.messages[1]

    assert isinstance(cleared_tool, ToolMessage)
    assert cleared_tool.content == "[cleared output]"
    assert cleared_tool.response_metadata["context_editing"]["cleared"] is True

    assert isinstance(cleared_ai, AIMessage)
    assert cleared_ai.tool_calls[0]["args"] == {}
    context_meta = cleared_ai.response_metadata.get("context_editing")
    assert context_meta is not None
    assert context_meta["cleared_tool_inputs"] == [tool_call_id]

    assert state["messages"] == request.messages


async def test_respects_keep_last_tool_results_async() -> None:
    """Test async version respects keep parameter for last tool results."""
    conversation: list[AIMessage | ToolMessage] = []
    edits = [
        ("call-a", "tool-output-a" * 5),
        ("call-b", "tool-output-b" * 5),
        ("call-c", "tool-output-c" * 5),
    ]

    for call_id, text in edits:
        conversation.append(
            AIMessage(
                content="",
                tool_calls=[{"id": call_id, "name": "tool", "args": {"input": call_id}}],
            )
        )
        conversation.append(ToolMessage(content=text, tool_call_id=call_id))

    state, request = _make_state_and_request(conversation)

    middleware = ContextEditingMiddleware(
        edits=[
            ClearToolUsesEdit(
                trigger=50,
                keep=1,
                placeholder="[cleared]",
            )
        ],
        token_count_method="model",
    )

    async def mock_handler(req: ModelRequest) -> AIMessage:
        return AIMessage(content="mock response")

    # Call awrap_model_call which modifies the request
    await middleware.awrap_model_call(request, mock_handler)

    cleared_messages = [
        msg
        for msg in request.messages
        if isinstance(msg, ToolMessage) and msg.content == "[cleared]"
    ]

    assert len(cleared_messages) == 2
    assert isinstance(request.messages[-1], ToolMessage)
    assert request.messages[-1].content != "[cleared]"


async def test_exclude_tools_prevents_clearing_async() -> None:
    """Test async version of excluding tools from clearing."""
    search_call = "call-search"
    calc_call = "call-calc"

    state, request = _make_state_and_request(
        [
            AIMessage(
                content="",
                tool_calls=[{"id": search_call, "name": "search", "args": {"query": "foo"}}],
            ),
            ToolMessage(content="search-results" * 20, tool_call_id=search_call),
            AIMessage(
                content="",
                tool_calls=[{"id": calc_call, "name": "calculator", "args": {"a": 1, "b": 2}}],
            ),
            ToolMessage(content="42", tool_call_id=calc_call),
        ]
    )

    middleware = ContextEditingMiddleware(
        edits=[
            ClearToolUsesEdit(
                trigger=50,
                clear_at_least=10,
                keep=0,
                exclude_tools=("search",),
                placeholder="[cleared]",
            )
        ],
    )

    async def mock_handler(req: ModelRequest) -> AIMessage:
        return AIMessage(content="mock response")

    # Call awrap_model_call which modifies the request
    await middleware.awrap_model_call(request, mock_handler)

    search_tool = request.messages[1]
    calc_tool = request.messages[3]

    assert isinstance(search_tool, ToolMessage)
    assert search_tool.content == "search-results" * 20

    assert isinstance(calc_tool, ToolMessage)
    assert calc_tool.content == "[cleared]"
@@ -44,7 +44,6 @@ from langchain.agents.middleware.model_call_limit import (
    ModelCallLimitExceededError,
)
from langchain.agents.middleware.model_fallback import ModelFallbackMiddleware
from langchain.agents.middleware.prompt_caching import AnthropicPromptCachingMiddleware
from langchain.agents.middleware.summarization import SummarizationMiddleware
from langchain.agents.middleware.types import (
    AgentMiddleware,
@@ -1024,115 +1023,6 @@ def test_human_in_the_loop_middleware_description_as_callable() -> None:
    assert captured_request["action_requests"][1]["description"] == "Static description"


# Tests for AnthropicPromptCachingMiddleware
def test_anthropic_prompt_caching_middleware_initialization() -> None:
    """Test AnthropicPromptCachingMiddleware initialization."""
    # Test with custom values
    middleware = AnthropicPromptCachingMiddleware(
        type="ephemeral", ttl="1h", min_messages_to_cache=5
    )
    assert middleware.type == "ephemeral"
    assert middleware.ttl == "1h"
    assert middleware.min_messages_to_cache == 5

    # Test with default values
    middleware = AnthropicPromptCachingMiddleware()
    assert middleware.type == "ephemeral"
    assert middleware.ttl == "5m"
    assert middleware.min_messages_to_cache == 0

    fake_request = ModelRequest(
        model=FakeToolCallingModel(),
        messages=[HumanMessage("Hello")],
        system_prompt=None,
        tool_choice=None,
        tools=[],
        response_format=None,
        state={"messages": [HumanMessage("Hello")]},
        runtime=cast(Runtime, object()),
        model_settings={},
    )

    def mock_handler(req: ModelRequest) -> AIMessage:
        return AIMessage(content="mock response", **req.model_settings)

    result = middleware.wrap_model_call(fake_request, mock_handler)
    # Check that model_settings were passed through via the request
    assert fake_request.model_settings == {"cache_control": {"type": "ephemeral", "ttl": "5m"}}


def test_anthropic_prompt_caching_middleware_unsupported_model() -> None:
    """Test AnthropicPromptCachingMiddleware with unsupported model."""
    from typing import cast

    fake_request = ModelRequest(
        model=FakeToolCallingModel(),
        messages=[HumanMessage("Hello")],
        system_prompt=None,
        tool_choice=None,
        tools=[],
        response_format=None,
        state={"messages": [HumanMessage("Hello")]},
        runtime=cast(Runtime, object()),
        model_settings={},
    )

    middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="raise")

    def mock_handler(req: ModelRequest) -> AIMessage:
        return AIMessage(content="mock response")

    with pytest.raises(
        ValueError,
        match="AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models. Please install langchain-anthropic.",
    ):
        middleware.wrap_model_call(fake_request, mock_handler)

    langchain_anthropic = ModuleType("langchain_anthropic")

    class MockChatAnthropic:
        pass

    langchain_anthropic.ChatAnthropic = MockChatAnthropic

    with patch.dict("sys.modules", {"langchain_anthropic": langchain_anthropic}):
        with pytest.raises(
            ValueError,
            match="AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models, not instances of",
        ):
            middleware.wrap_model_call(fake_request, mock_handler)

    middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="warn")

    with warnings.catch_warnings(record=True) as w:
        result = middleware.wrap_model_call(fake_request, mock_handler)
        assert len(w) == 1
        assert (
            "AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models. Please install langchain-anthropic."
            in str(w[-1].message)
        )
        assert isinstance(result, AIMessage)

    with warnings.catch_warnings(record=True) as w:
        with patch.dict("sys.modules", {"langchain_anthropic": langchain_anthropic}):
            result = middleware.wrap_model_call(fake_request, mock_handler)
            assert isinstance(result, AIMessage)
        assert len(w) == 1
        assert (
            "AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models, not instances of"
            in str(w[-1].message)
        )

    middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore")

    result = middleware.wrap_model_call(fake_request, mock_handler)
    assert isinstance(result, AIMessage)

    with patch.dict("sys.modules", {"langchain_anthropic": {"ChatAnthropic": object()}}):
        result = middleware.wrap_model_call(fake_request, mock_handler)
        assert isinstance(result, AIMessage)


# Tests for SummarizationMiddleware
def test_summarization_middleware_initialization() -> None:
    """Test SummarizationMiddleware initialization."""
@@ -0,0 +1,220 @@
"""Unit tests for ModelFallbackMiddleware."""

from __future__ import annotations

from typing import cast

import pytest
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage

from langchain.agents.middleware.model_fallback import ModelFallbackMiddleware
from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langgraph.runtime import Runtime


def _fake_runtime() -> Runtime:
    return cast(Runtime, object())


def _make_request() -> ModelRequest:
    """Create a minimal ModelRequest for testing."""
    model = GenericFakeChatModel(messages=iter([AIMessage(content="primary")]))
    return ModelRequest(
        model=model,
        system_prompt=None,
        messages=[],
        tool_choice=None,
        tools=[],
        response_format=None,
        state=cast("AgentState", {}),  # type: ignore[name-defined]
        runtime=_fake_runtime(),
        model_settings={},
    )


def test_primary_model_succeeds() -> None:
    """Test that primary model is used when it succeeds."""
    primary_model = GenericFakeChatModel(messages=iter([AIMessage(content="primary response")]))
    fallback_model = GenericFakeChatModel(messages=iter([AIMessage(content="fallback response")]))

    middleware = ModelFallbackMiddleware(fallback_model)
    request = _make_request()
    request.model = primary_model

    def mock_handler(req: ModelRequest) -> ModelResponse:
        # Simulate successful model call
        result = req.model.invoke([])
        return ModelResponse(result=[result])

    response = middleware.wrap_model_call(request, mock_handler)

    assert isinstance(response, ModelResponse)
    assert response.result[0].content == "primary response"


def test_fallback_on_primary_failure() -> None:
    """Test that fallback model is used when primary fails."""

    class FailingPrimaryModel(GenericFakeChatModel):
        def _generate(self, messages, **kwargs):
            raise ValueError("Primary model failed")

    primary_model = FailingPrimaryModel(messages=iter([AIMessage(content="should not see")]))
    fallback_model = GenericFakeChatModel(messages=iter([AIMessage(content="fallback response")]))

    middleware = ModelFallbackMiddleware(fallback_model)
    request = _make_request()
    request.model = primary_model

    def mock_handler(req: ModelRequest) -> ModelResponse:
        result = req.model.invoke([])
        return ModelResponse(result=[result])

    response = middleware.wrap_model_call(request, mock_handler)

    assert isinstance(response, ModelResponse)
    assert response.result[0].content == "fallback response"


def test_multiple_fallbacks() -> None:
    """Test that multiple fallback models are tried in sequence."""

    class FailingModel(GenericFakeChatModel):
        def _generate(self, messages, **kwargs):
            raise ValueError("Model failed")

    primary_model = FailingModel(messages=iter([AIMessage(content="should not see")]))
    fallback1 = FailingModel(messages=iter([AIMessage(content="fallback1")]))
    fallback2 = GenericFakeChatModel(messages=iter([AIMessage(content="fallback2")]))

    middleware = ModelFallbackMiddleware(fallback1, fallback2)
    request = _make_request()
    request.model = primary_model

    def mock_handler(req: ModelRequest) -> ModelResponse:
        result = req.model.invoke([])
        return ModelResponse(result=[result])

    response = middleware.wrap_model_call(request, mock_handler)

    assert isinstance(response, ModelResponse)
    assert response.result[0].content == "fallback2"


def test_all_models_fail() -> None:
    """Test that exception is raised when all models fail."""

    class AlwaysFailingModel(GenericFakeChatModel):
        def _generate(self, messages, **kwargs):
            raise ValueError("Model failed")

    primary_model = AlwaysFailingModel(messages=iter([]))
    fallback_model = AlwaysFailingModel(messages=iter([]))

    middleware = ModelFallbackMiddleware(fallback_model)
    request = _make_request()
    request.model = primary_model

    def mock_handler(req: ModelRequest) -> ModelResponse:
        result = req.model.invoke([])
        return ModelResponse(result=[result])

    with pytest.raises(ValueError, match="Model failed"):
        middleware.wrap_model_call(request, mock_handler)


# ==============================================================================
# Async Tests
# ==============================================================================


async def test_primary_model_succeeds_async() -> None:
    """Test async version - primary model is used when it succeeds."""
    primary_model = GenericFakeChatModel(messages=iter([AIMessage(content="primary response")]))
    fallback_model = GenericFakeChatModel(messages=iter([AIMessage(content="fallback response")]))

    middleware = ModelFallbackMiddleware(fallback_model)
    request = _make_request()
    request.model = primary_model

    async def mock_handler(req: ModelRequest) -> ModelResponse:
        # Simulate successful async model call
        result = await req.model.ainvoke([])
        return ModelResponse(result=[result])

    response = await middleware.awrap_model_call(request, mock_handler)

    assert isinstance(response, ModelResponse)
    assert response.result[0].content == "primary response"


async def test_fallback_on_primary_failure_async() -> None:
    """Test async version - fallback model is used when primary fails."""

    class AsyncFailingPrimaryModel(GenericFakeChatModel):
        async def _agenerate(self, messages, **kwargs):
            raise ValueError("Primary model failed")

    primary_model = AsyncFailingPrimaryModel(messages=iter([AIMessage(content="should not see")]))
    fallback_model = GenericFakeChatModel(messages=iter([AIMessage(content="fallback response")]))

    middleware = ModelFallbackMiddleware(fallback_model)
    request = _make_request()
    request.model = primary_model

    async def mock_handler(req: ModelRequest) -> ModelResponse:
        result = await req.model.ainvoke([])
        return ModelResponse(result=[result])

    response = await middleware.awrap_model_call(request, mock_handler)

    assert isinstance(response, ModelResponse)
    assert response.result[0].content == "fallback response"


async def test_multiple_fallbacks_async() -> None:
    """Test async version - multiple fallback models are tried in sequence."""

    class AsyncFailingModel(GenericFakeChatModel):
        async def _agenerate(self, messages, **kwargs):
            raise ValueError("Model failed")

    primary_model = AsyncFailingModel(messages=iter([AIMessage(content="should not see")]))
    fallback1 = AsyncFailingModel(messages=iter([AIMessage(content="fallback1")]))
    fallback2 = GenericFakeChatModel(messages=iter([AIMessage(content="fallback2")]))

    middleware = ModelFallbackMiddleware(fallback1, fallback2)
    request = _make_request()
    request.model = primary_model

    async def mock_handler(req: ModelRequest) -> ModelResponse:
        result = await req.model.ainvoke([])
        return ModelResponse(result=[result])

    response = await middleware.awrap_model_call(request, mock_handler)

    assert isinstance(response, ModelResponse)
    assert response.result[0].content == "fallback2"


async def test_all_models_fail_async() -> None:
    """Test async version - exception is raised when all models fail."""

    class AsyncAlwaysFailingModel(GenericFakeChatModel):
        async def _agenerate(self, messages, **kwargs):
            raise ValueError("Model failed")

    primary_model = AsyncAlwaysFailingModel(messages=iter([]))
    fallback_model = AsyncAlwaysFailingModel(messages=iter([]))

    middleware = ModelFallbackMiddleware(fallback_model)
    request = _make_request()
    request.model = primary_model

    async def mock_handler(req: ModelRequest) -> ModelResponse:
        result = await req.model.ainvoke([])
        return ModelResponse(result=[result])

    with pytest.raises(ValueError, match="Model failed"):
        await middleware.awrap_model_call(request, mock_handler)
@@ -0,0 +1,172 @@
"""Unit tests for PlanningMiddleware."""

from __future__ import annotations

from typing import cast

from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage

from langchain.agents.middleware.planning import PlanningMiddleware
from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langgraph.runtime import Runtime


def _fake_runtime() -> Runtime:
    return cast(Runtime, object())


def _make_request(system_prompt: str | None = None) -> ModelRequest:
    """Create a minimal ModelRequest for testing."""
    model = GenericFakeChatModel(messages=iter([AIMessage(content="response")]))
    return ModelRequest(
        model=model,
        system_prompt=system_prompt,
        messages=[],
        tool_choice=None,
        tools=[],
        response_format=None,
        state=cast("AgentState", {}),  # type: ignore[name-defined]
        runtime=_fake_runtime(),
        model_settings={},
    )


def test_adds_system_prompt_when_none_exists() -> None:
    """Test that middleware adds system prompt when request has none."""
    middleware = PlanningMiddleware()
    request = _make_request(system_prompt=None)

    def mock_handler(req: ModelRequest) -> ModelResponse:
        return ModelResponse(result=[AIMessage(content="response")])

    middleware.wrap_model_call(request, mock_handler)

    # System prompt should be set
    assert request.system_prompt is not None
    assert "write_todos" in request.system_prompt


def test_appends_to_existing_system_prompt() -> None:
    """Test that middleware appends to existing system prompt."""
    existing_prompt = "You are a helpful assistant."
    middleware = PlanningMiddleware()
    request = _make_request(system_prompt=existing_prompt)

    def mock_handler(req: ModelRequest) -> ModelResponse:
        return ModelResponse(result=[AIMessage(content="response")])

    middleware.wrap_model_call(request, mock_handler)

    # System prompt should contain both
    assert request.system_prompt is not None
    assert existing_prompt in request.system_prompt
    assert "write_todos" in request.system_prompt
    assert request.system_prompt.startswith(existing_prompt)


def test_custom_system_prompt() -> None:
    """Test that middleware uses custom system prompt."""
    custom_prompt = "Custom planning instructions"
    middleware = PlanningMiddleware(system_prompt=custom_prompt)
    request = _make_request(system_prompt=None)

    def mock_handler(req: ModelRequest) -> ModelResponse:
        return ModelResponse(result=[AIMessage(content="response")])

    middleware.wrap_model_call(request, mock_handler)

    # Should use custom prompt
    assert request.system_prompt == custom_prompt


def test_has_write_todos_tool() -> None:
    """Test that middleware registers the write_todos tool."""
    middleware = PlanningMiddleware()

    # Should have one tool registered
    assert len(middleware.tools) == 1
    assert middleware.tools[0].name == "write_todos"


def test_custom_tool_description() -> None:
    """Test that middleware uses custom tool description."""
    custom_description = "Custom todo tool description"
    middleware = PlanningMiddleware(tool_description=custom_description)

    # Tool should use custom description
    assert len(middleware.tools) == 1
    assert middleware.tools[0].description == custom_description


# ==============================================================================
# Async Tests
# ==============================================================================


async def test_adds_system_prompt_when_none_exists_async() -> None:
    """Test async version - middleware adds system prompt when request has none."""
    middleware = PlanningMiddleware()
    request = _make_request(system_prompt=None)

    async def mock_handler(req: ModelRequest) -> ModelResponse:
        return ModelResponse(result=[AIMessage(content="response")])

    await middleware.awrap_model_call(request, mock_handler)

    # System prompt should be set
    assert request.system_prompt is not None
    assert "write_todos" in request.system_prompt


async def test_appends_to_existing_system_prompt_async() -> None:
    """Test async version - middleware appends to existing system prompt."""
    existing_prompt = "You are a helpful assistant."
    middleware = PlanningMiddleware()
    request = _make_request(system_prompt=existing_prompt)

    async def mock_handler(req: ModelRequest) -> ModelResponse:
        return ModelResponse(result=[AIMessage(content="response")])

    await middleware.awrap_model_call(request, mock_handler)

    # System prompt should contain both
    assert request.system_prompt is not None
    assert existing_prompt in request.system_prompt
    assert "write_todos" in request.system_prompt
    assert request.system_prompt.startswith(existing_prompt)


async def test_custom_system_prompt_async() -> None:
    """Test async version - middleware uses custom system prompt."""
    custom_prompt = "Custom planning instructions"
    middleware = PlanningMiddleware(system_prompt=custom_prompt)
    request = _make_request(system_prompt=None)

    async def mock_handler(req: ModelRequest) -> ModelResponse:
        return ModelResponse(result=[AIMessage(content="response")])

    await middleware.awrap_model_call(request, mock_handler)

    # Should use custom prompt
    assert request.system_prompt == custom_prompt


async def test_handler_called_with_modified_request_async() -> None:
    """Test async version - handler receives the modified request."""
    middleware = PlanningMiddleware()
    request = _make_request(system_prompt="Original")
    handler_called = {"value": False}
    received_prompt = {"value": None}

    async def mock_handler(req: ModelRequest) -> ModelResponse:
        handler_called["value"] = True
        received_prompt["value"] = req.system_prompt
        return ModelResponse(result=[AIMessage(content="response")])

    await middleware.awrap_model_call(request, mock_handler)

    assert handler_called["value"]
    assert received_prompt["value"] is not None
    assert "Original" in received_prompt["value"]
    assert "write_todos" in received_prompt["value"]
@@ -0,0 +1,9 @@
"""Middleware for Anthropic models."""

from langchain_anthropic.middleware.prompt_caching import (
    AnthropicPromptCachingMiddleware,
)

__all__ = [
    "AnthropicPromptCachingMiddleware",
]
@@ -0,0 +1,157 @@
"""Anthropic prompt caching middleware.

Requires:
    - langchain: For agent middleware framework
    - langchain-anthropic: For ChatAnthropic model (already a dependency)
"""

from collections.abc import Awaitable, Callable
from typing import Literal
from warnings import warn

try:
    from langchain.agents.middleware.types import (
        AgentMiddleware,
        ModelCallResult,
        ModelRequest,
        ModelResponse,
    )
except ImportError as e:
    msg = (
        "AnthropicPromptCachingMiddleware requires 'langchain' to be installed. "
        "This middleware is designed for use with LangChain agents. "
        "Install it with: pip install langchain"
    )
    raise ImportError(msg) from e


class AnthropicPromptCachingMiddleware(AgentMiddleware):
    """Prompt Caching Middleware.

    Optimizes API usage by caching conversation prefixes for Anthropic models.

    Requires both 'langchain' and 'langchain-anthropic' packages to be installed.

    Learn more about Anthropic prompt caching
    [here](https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching).
    """

    def __init__(
        self,
        type: Literal["ephemeral"] = "ephemeral",  # noqa: A002
        ttl: Literal["5m", "1h"] = "5m",
        min_messages_to_cache: int = 0,
        unsupported_model_behavior: Literal["ignore", "warn", "raise"] = "warn",
    ) -> None:
        """Initialize the middleware with cache control settings.

        Args:
            type: The type of cache to use, only "ephemeral" is supported.
            ttl: The time to live for the cache, only "5m" and "1h" are
                supported.
            min_messages_to_cache: The minimum number of messages until the
                cache is used, default is 0.
            unsupported_model_behavior: The behavior to take when an
                unsupported model is used. "ignore" will ignore the unsupported
                model and continue without caching. "warn" will warn the user
                and continue without caching. "raise" will raise an error and
                stop the agent.
        """
        self.type = type
        self.ttl = ttl
        self.min_messages_to_cache = min_messages_to_cache
        self.unsupported_model_behavior = unsupported_model_behavior

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelCallResult:
        """Modify the model request to add cache control blocks."""
        try:
            from langchain_anthropic import ChatAnthropic

            chat_anthropic_cls: type | None = ChatAnthropic
        except ImportError:
            chat_anthropic_cls = None

        msg: str | None = None

        if chat_anthropic_cls is None:
            msg = (
                "AnthropicPromptCachingMiddleware caching middleware only supports "
                "Anthropic models. "
                "Please install langchain-anthropic."
            )
        elif not isinstance(request.model, chat_anthropic_cls):
            msg = (
                "AnthropicPromptCachingMiddleware caching middleware only supports "
                f"Anthropic models, not instances of {type(request.model)}"
            )

        if msg is not None:
            if self.unsupported_model_behavior == "raise":
                raise ValueError(msg)
            if self.unsupported_model_behavior == "warn":
                warn(msg, stacklevel=3)
            else:
                return handler(request)

        messages_count = (
            len(request.messages) + 1
            if request.system_prompt
            else len(request.messages)
        )
        if messages_count < self.min_messages_to_cache:
            return handler(request)

        request.model_settings["cache_control"] = {"type": self.type, "ttl": self.ttl}

        return handler(request)

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelCallResult:
        """Modify the model request to add cache control blocks (async version)."""
        try:
            from langchain_anthropic import ChatAnthropic

            chat_anthropic_cls: type | None = ChatAnthropic
        except ImportError:
            chat_anthropic_cls = None

        msg: str | None = None

        if chat_anthropic_cls is None:
            msg = (
                "AnthropicPromptCachingMiddleware caching middleware only supports "
                "Anthropic models. "
                "Please install langchain-anthropic."
            )
        elif not isinstance(request.model, chat_anthropic_cls):
            msg = (
                "AnthropicPromptCachingMiddleware caching middleware only supports "
                f"Anthropic models, not instances of {type(request.model)}"
            )

        if msg is not None:
            if self.unsupported_model_behavior == "raise":
                raise ValueError(msg)
            if self.unsupported_model_behavior == "warn":
                warn(msg, stacklevel=3)
            else:
                return await handler(request)

        messages_count = (
            len(request.messages) + 1
            if request.system_prompt
            else len(request.messages)
        )
        if messages_count < self.min_messages_to_cache:
            return await handler(request)

        request.model_settings["cache_control"] = {"type": self.type, "ttl": self.ttl}

        return await handler(request)
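A rough sketch of attaching the relocated middleware to an agent. Only the middleware construction and the resulting cache_control entry in request.model_settings are confirmed by this diff and its tests; the create_agent wiring and the model name are assumptions:

from langchain_anthropic import ChatAnthropic
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware

# Assumption: an agent entry point that accepts a middleware sequence,
# e.g. langchain's create_agent; adjust to whatever constructor you use.
from langchain.agents import create_agent  # assumed entry point

agent = create_agent(
    model=ChatAnthropic(model="claude-sonnet-4-5"),  # hypothetical model name
    tools=[],
    middleware=[
        AnthropicPromptCachingMiddleware(ttl="5m", min_messages_to_cache=3),
    ],
)
# Once the message count reaches the threshold, the middleware sets
# request.model_settings["cache_control"] = {"type": "ephemeral", "ttl": "5m"}.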
@@ -41,6 +41,7 @@ test = [
    "vcrpy>=7.0.0,<8.0.0",
    "langchain-core",
    "langchain-tests",
    "langchain",
]
lint = ["ruff>=0.13.1,<0.14.0"]
dev = ["langchain-core"]
@@ -55,6 +56,7 @@ typing = [
[tool.uv.sources]
langchain-core = { path = "../../core", editable = true }
langchain-tests = { path = "../../standard-tests", editable = true }
langchain = { path = "../../langchain_v1", editable = true }

[tool.mypy]
disallow_untyped_defs = "True"
@@ -0,0 +1 @@
"""Tests for Anthropic middleware."""
@@ -0,0 +1,238 @@
"""Tests for Anthropic prompt caching middleware."""

import warnings
from typing import Any, cast

import pytest
from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langgraph.runtime import Runtime

from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware


class FakeToolCallingModel(BaseChatModel):
    """Fake model for testing middleware."""

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Top Level call"""
        messages_string = "-".join([str(m.content) for m in messages])
        message = AIMessage(content=messages_string, id="0")
        return ChatResult(generations=[ChatGeneration(message=message)])

    async def _agenerate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: AsyncCallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Async top level call"""
        messages_string = "-".join([str(m.content) for m in messages])
        message = AIMessage(content=messages_string, id="0")
        return ChatResult(generations=[ChatGeneration(message=message)])

    @property
    def _llm_type(self) -> str:
        return "fake-tool-call-model"


def test_anthropic_prompt_caching_middleware_initialization() -> None:
    """Test AnthropicPromptCachingMiddleware initialization."""
    # Test with custom values
    middleware = AnthropicPromptCachingMiddleware(
        type="ephemeral", ttl="1h", min_messages_to_cache=5
    )
    assert middleware.type == "ephemeral"
    assert middleware.ttl == "1h"
    assert middleware.min_messages_to_cache == 5

    # Test with default values
    middleware = AnthropicPromptCachingMiddleware()
    assert middleware.type == "ephemeral"
    assert middleware.ttl == "5m"
    assert middleware.min_messages_to_cache == 0

    fake_request = ModelRequest(
        model=FakeToolCallingModel(),
        messages=[HumanMessage("Hello")],
        system_prompt=None,
        tool_choice=None,
        tools=[],
        response_format=None,
        state={"messages": [HumanMessage("Hello")]},
        runtime=cast(Runtime, object()),
        model_settings={},
    )

    def mock_handler(req: ModelRequest) -> ModelResponse:
        return ModelResponse(result=[AIMessage(content="mock response")])

    middleware.wrap_model_call(fake_request, mock_handler)
    # Check that model_settings were passed through via the request
    assert fake_request.model_settings == {
        "cache_control": {"type": "ephemeral", "ttl": "5m"}
    }


def test_anthropic_prompt_caching_middleware_unsupported_model() -> None:
    """Test AnthropicPromptCachingMiddleware with unsupported model."""
    fake_request = ModelRequest(
        model=FakeToolCallingModel(),
        messages=[HumanMessage("Hello")],
        system_prompt=None,
        tool_choice=None,
        tools=[],
        response_format=None,
        state={"messages": [HumanMessage("Hello")]},
        runtime=cast(Runtime, object()),
        model_settings={},
    )

    middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="raise")

    def mock_handler(req: ModelRequest) -> ModelResponse:
        return ModelResponse(result=[AIMessage(content="mock response")])

    # Since we're in the langchain-anthropic package, ChatAnthropic is always
    # available. Test that it raises an error for unsupported model instances
    with pytest.raises(
        ValueError,
        match=(
            "AnthropicPromptCachingMiddleware caching middleware only supports "
            "Anthropic models, not instances of"
        ),
    ):
        middleware.wrap_model_call(fake_request, mock_handler)

    middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="warn")

    # Test warn behavior for unsupported model instances
    with warnings.catch_warnings(record=True) as w:
        result = middleware.wrap_model_call(fake_request, mock_handler)
        assert isinstance(result, ModelResponse)
        assert len(w) == 1
        assert (
            "AnthropicPromptCachingMiddleware caching middleware only supports "
            "Anthropic models, not instances of"
        ) in str(w[-1].message)

    # Test ignore behavior
    middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore")
    result = middleware.wrap_model_call(fake_request, mock_handler)
    assert isinstance(result, ModelResponse)


async def test_anthropic_prompt_caching_middleware_async() -> None:
    """Test AnthropicPromptCachingMiddleware async path."""
    # Test with custom values
    middleware = AnthropicPromptCachingMiddleware(
        type="ephemeral", ttl="1h", min_messages_to_cache=5
    )

    fake_request = ModelRequest(
        model=FakeToolCallingModel(),
        messages=[HumanMessage("Hello")] * 6,
        system_prompt=None,
        tool_choice=None,
        tools=[],
        response_format=None,
        state={"messages": [HumanMessage("Hello")] * 6},
        runtime=cast(Runtime, object()),
        model_settings={},
    )

    async def mock_handler(req: ModelRequest) -> ModelResponse:
        return ModelResponse(result=[AIMessage(content="mock response")])

    result = await middleware.awrap_model_call(fake_request, mock_handler)
    assert isinstance(result, ModelResponse)
    # Check that model_settings were passed through via the request
    assert fake_request.model_settings == {
        "cache_control": {"type": "ephemeral", "ttl": "1h"}
    }


async def test_anthropic_prompt_caching_middleware_async_unsupported_model() -> None:
    """Test AnthropicPromptCachingMiddleware async path with unsupported model."""
    fake_request = ModelRequest(
        model=FakeToolCallingModel(),
        messages=[HumanMessage("Hello")],
        system_prompt=None,
        tool_choice=None,
        tools=[],
        response_format=None,
        state={"messages": [HumanMessage("Hello")]},
        runtime=cast(Runtime, object()),
        model_settings={},
    )

    middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="raise")

    async def mock_handler(req: ModelRequest) -> ModelResponse:
        return ModelResponse(result=[AIMessage(content="mock response")])

    # Test that it raises an error for unsupported model instances
    with pytest.raises(
        ValueError,
        match=(
            "AnthropicPromptCachingMiddleware caching middleware only supports "
            "Anthropic models, not instances of"
        ),
    ):
        await middleware.awrap_model_call(fake_request, mock_handler)

    middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="warn")

    # Test warn behavior for unsupported model instances
    with warnings.catch_warnings(record=True) as w:
        result = await middleware.awrap_model_call(fake_request, mock_handler)
        assert isinstance(result, ModelResponse)
        assert len(w) == 1
        assert (
            "AnthropicPromptCachingMiddleware caching middleware only supports "
            "Anthropic models, not instances of"
        ) in str(w[-1].message)

    # Test ignore behavior
    middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore")
    result = await middleware.awrap_model_call(fake_request, mock_handler)
    assert isinstance(result, ModelResponse)


async def test_anthropic_prompt_caching_middleware_async_min_messages() -> None:
    """Test async path respects min_messages_to_cache."""
    middleware = AnthropicPromptCachingMiddleware(min_messages_to_cache=5)

    # Test with fewer messages than minimum
    fake_request = ModelRequest(
        model=FakeToolCallingModel(),
        messages=[HumanMessage("Hello")] * 3,
        system_prompt=None,
        tool_choice=None,
        tools=[],
        response_format=None,
        state={"messages": [HumanMessage("Hello")] * 3},
        runtime=cast(Runtime, object()),
        model_settings={},
    )

    async def mock_handler(req: ModelRequest) -> ModelResponse:
        return ModelResponse(result=[AIMessage(content="mock response")])

    result = await middleware.awrap_model_call(fake_request, mock_handler)
    assert isinstance(result, ModelResponse)
    # Cache control should NOT be added when message count is below minimum
    assert fake_request.model_settings == {}
libs/partners/anthropic/uv.lock (generated, 1382 changed lines): file diff suppressed because it is too large.