Compare commits

...

8 Commits

Author          SHA1        Message                               Date
Eugene Yurtsev  fed37ddf96  x                                     2025-10-13 15:44:56 -04:00
Eugene Yurtsev  9f6b660f00  x                                     2025-10-13 15:43:53 -04:00
Eugene Yurtsev  00c2380019  x                                     2025-10-13 14:42:48 -04:00
Eugene Yurtsev  c53ffe4b82  x                                     2025-10-13 14:42:18 -04:00
Eugene Yurtsev  23f5b0cedf  x                                     2025-10-13 14:34:36 -04:00
Eugene Yurtsev  89e3a10cbd  langchain-anthropic middleware        2025-10-13 10:05:14 -04:00
Eugene Yurtsev  fad0e2dbd8  update anthropic pyproject / uv lock  2025-10-13 10:02:46 -04:00
Eugene Yurtsev  6c2f0eb67a  update langchain_v1                   2025-10-13 10:02:19 -04:00
15 changed files with 1955 additions and 676 deletions

View File

@@ -12,7 +12,6 @@ from .model_call_limit import ModelCallLimitMiddleware
from .model_fallback import ModelFallbackMiddleware
from .pii import PIIDetectionError, PIIMiddleware
from .planning import PlanningMiddleware
from .prompt_caching import AnthropicPromptCachingMiddleware
from .summarization import SummarizationMiddleware
from .tool_call_limit import ToolCallLimitMiddleware
from .tool_emulator import LLMToolEmulator
@@ -33,8 +32,6 @@ from .types import (
__all__ = [
"AgentMiddleware",
"AgentState",
# should move to langchain-anthropic if we decide to keep it
"AnthropicPromptCachingMiddleware",
"ClearToolUsesEdit",
"ContextEditingMiddleware",
"HumanInTheLoopMiddleware",

View File

@@ -8,7 +8,7 @@ with any LangChain chat model.
from __future__ import annotations
from collections.abc import Callable, Iterable, Sequence
from collections.abc import Awaitable, Callable, Iterable, Sequence
from dataclasses import dataclass
from typing import Literal
@@ -239,6 +239,34 @@ class ContextEditingMiddleware(AgentMiddleware):
return handler(request)
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
"""Apply context edits before invoking the model via handler (async version)."""
if not request.messages:
return await handler(request)
if self.token_count_method == "approximate": # noqa: S105
def count_tokens(messages: Sequence[BaseMessage]) -> int:
return count_tokens_approximately(messages)
else:
system_msg = (
[SystemMessage(content=request.system_prompt)] if request.system_prompt else []
)
def count_tokens(messages: Sequence[BaseMessage]) -> int:
return request.model.get_num_tokens_from_messages(
system_msg + list(messages), request.tools
)
for edit in self.edits:
edit.apply(request.messages, count_tokens=count_tokens)
return await handler(request)
__all__ = [
"ClearToolUsesEdit",

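For quick reference, a minimal runnable sketch of the new async hook. It reuses the numbers from test_clear_tool_outputs_and_inputs_async further down, and borrows the ModelRequest construction and cast-based fake runtime from the test helpers in this diff; the fake model and message contents are illustrative assumptions, not part of the change.

import asyncio
from typing import cast

from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage, ToolMessage
from langgraph.runtime import Runtime

from langchain.agents.middleware import ClearToolUsesEdit, ContextEditingMiddleware
from langchain.agents.middleware.types import ModelRequest


async def main() -> None:
    messages = [
        AIMessage(content="", tool_calls=[{"id": "call-1", "name": "search", "args": {}}]),
        ToolMessage(content="x" * 200, tool_call_id="call-1"),
    ]
    request = ModelRequest(
        model=GenericFakeChatModel(messages=iter([AIMessage(content="ok")])),
        system_prompt=None,
        messages=messages,
        tool_choice=None,
        tools=[],
        response_format=None,
        state=cast("AgentState", {"messages": messages}),  # type: ignore[name-defined]
        runtime=cast(Runtime, object()),
        model_settings={},
    )
    middleware = ContextEditingMiddleware(
        edits=[ClearToolUsesEdit(trigger=50, clear_at_least=10, keep=0, placeholder="[cleared]")]
    )

    async def handler(req: ModelRequest) -> AIMessage:
        return AIMessage(content="mock response")

    # awrap_model_call edits request.messages in place, then awaits the handler.
    await middleware.awrap_model_call(request, handler)
    assert request.messages[1].content == "[cleared]"


asyncio.run(main())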
View File

@@ -13,7 +13,7 @@ from langchain.agents.middleware.types import (
from langchain.chat_models import init_chat_model
if TYPE_CHECKING:
from collections.abc import Callable
from collections.abc import Awaitable, Callable
from langchain_core.language_models.chat_models import BaseChatModel
@@ -102,3 +102,38 @@ class ModelFallbackMiddleware(AgentMiddleware):
continue
raise last_exception
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
"""Try fallback models in sequence on errors (async version).
Args:
request: Initial model request.
handler: Async callback to execute the model.
Returns:
AIMessage from successful model call.
Raises:
Exception: If all models fail, re-raises last exception.
"""
# Try primary model first
last_exception: Exception
try:
return await handler(request)
except Exception as e: # noqa: BLE001
last_exception = e
# Try fallback models
for fallback_model in self.models:
request.model = fallback_model
try:
return await handler(request)
except Exception as e: # noqa: BLE001
last_exception = e
continue
raise last_exception
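
The async fallback path can be exercised the same way; below is a sketch assembled from the helpers in the new test file later in this diff (GenericFakeChatModel, the cast-based fake runtime). The failing model and message contents are illustrative assumptions.

import asyncio
from typing import cast

from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage
from langgraph.runtime import Runtime

from langchain.agents.middleware.model_fallback import ModelFallbackMiddleware
from langchain.agents.middleware.types import ModelRequest, ModelResponse


class FailingModel(GenericFakeChatModel):
    async def _agenerate(self, messages, **kwargs):
        raise ValueError("primary failed")


async def main() -> None:
    request = ModelRequest(
        model=FailingModel(messages=iter([])),
        system_prompt=None,
        messages=[],
        tool_choice=None,
        tools=[],
        response_format=None,
        state=cast("AgentState", {}),  # type: ignore[name-defined]
        runtime=cast(Runtime, object()),
        model_settings={},
    )
    fallback = GenericFakeChatModel(messages=iter([AIMessage(content="fallback")]))
    middleware = ModelFallbackMiddleware(fallback)

    async def handler(req: ModelRequest) -> ModelResponse:
        return ModelResponse(result=[await req.model.ainvoke([])])

    # The primary model raises, so the middleware retries with the fallback.
    response = await middleware.awrap_model_call(request, handler)
    assert response.result[0].content == "fallback"


asyncio.run(main())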

View File

@@ -6,7 +6,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Annotated, Literal
if TYPE_CHECKING:
from collections.abc import Callable
from collections.abc import Awaitable, Callable
from langchain_core.messages import ToolMessage
from langchain_core.tools import tool
@@ -204,3 +204,16 @@ class PlanningMiddleware(AgentMiddleware):
else self.system_prompt
)
return handler(request)
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
"""Update the system prompt to include the todo system prompt (async version)."""
request.system_prompt = (
request.system_prompt + "\n\n" + self.system_prompt
if request.system_prompt
else self.system_prompt
)
return await handler(request)
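
PlanningMiddleware's async hook only rewrites request.system_prompt before delegating; here is a short sketch built from the planning-test helpers added later in this diff (the existing prompt text is an illustrative assumption).

import asyncio
from typing import cast

from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage
from langgraph.runtime import Runtime

from langchain.agents.middleware.planning import PlanningMiddleware
from langchain.agents.middleware.types import ModelRequest, ModelResponse


async def main() -> None:
    request = ModelRequest(
        model=GenericFakeChatModel(messages=iter([AIMessage(content="response")])),
        system_prompt="You are a helpful assistant.",
        messages=[],
        tool_choice=None,
        tools=[],
        response_format=None,
        state=cast("AgentState", {}),  # type: ignore[name-defined]
        runtime=cast(Runtime, object()),
        model_settings={},
    )

    async def handler(req: ModelRequest) -> ModelResponse:
        return ModelResponse(result=[AIMessage(content="response")])

    await PlanningMiddleware().awrap_model_call(request, handler)
    # The original prompt is kept and the todo instructions are appended.
    assert request.system_prompt.startswith("You are a helpful assistant.")
    assert "write_todos" in request.system_prompt


asyncio.run(main())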

View File

@@ -1,89 +0,0 @@
"""Anthropic prompt caching middleware."""
from collections.abc import Callable
from typing import Literal
from warnings import warn
from langchain.agents.middleware.types import (
AgentMiddleware,
ModelCallResult,
ModelRequest,
ModelResponse,
)
class AnthropicPromptCachingMiddleware(AgentMiddleware):
"""Prompt Caching Middleware.
Optimizes API usage by caching conversation prefixes for Anthropic models.
Learn more about Anthropic prompt caching
[here](https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching).
"""
def __init__(
self,
type: Literal["ephemeral"] = "ephemeral",
ttl: Literal["5m", "1h"] = "5m",
min_messages_to_cache: int = 0,
unsupported_model_behavior: Literal["ignore", "warn", "raise"] = "warn",
) -> None:
"""Initialize the middleware with cache control settings.
Args:
type: The type of cache to use, only "ephemeral" is supported.
ttl: The time to live for the cache, only "5m" and "1h" are supported.
min_messages_to_cache: The minimum number of messages until the cache is used,
default is 0.
unsupported_model_behavior: The behavior to take when an unsupported model is used.
"ignore" will ignore the unsupported model and continue without caching.
"warn" will warn the user and continue without caching.
"raise" will raise an error and stop the agent.
"""
self.type = type
self.ttl = ttl
self.min_messages_to_cache = min_messages_to_cache
self.unsupported_model_behavior = unsupported_model_behavior
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
"""Modify the model request to add cache control blocks."""
try:
from langchain_anthropic import ChatAnthropic
except ImportError:
ChatAnthropic = None # noqa: N806
msg: str | None = None
if ChatAnthropic is None:
msg = (
"AnthropicPromptCachingMiddleware caching middleware only supports "
"Anthropic models. "
"Please install langchain-anthropic."
)
elif not isinstance(request.model, ChatAnthropic):
msg = (
"AnthropicPromptCachingMiddleware caching middleware only supports "
f"Anthropic models, not instances of {type(request.model)}"
)
if msg is not None:
if self.unsupported_model_behavior == "raise":
raise ValueError(msg)
if self.unsupported_model_behavior == "warn":
warn(msg, stacklevel=3)
else:
return handler(request)
messages_count = (
len(request.messages) + 1 if request.system_prompt else len(request.messages)
)
if messages_count < self.min_messages_to_cache:
return handler(request)
request.model_settings["cache_control"] = {"type": self.type, "ttl": self.ttl}
return handler(request)

View File

@@ -233,3 +233,169 @@ def test_exclude_tools_prevents_clearing() -> None:
def _fake_runtime() -> Runtime:
return cast(Runtime, object())
async def test_no_edit_when_below_trigger_async() -> None:
"""Test async version of context editing with no edit when below trigger."""
tool_call_id = "call-1"
ai_message = AIMessage(
content="",
tool_calls=[{"id": tool_call_id, "name": "search", "args": {}}],
)
tool_message = ToolMessage(content="12345", tool_call_id=tool_call_id)
state, request = _make_state_and_request([ai_message, tool_message])
middleware = ContextEditingMiddleware(
edits=[ClearToolUsesEdit(trigger=50)],
)
async def mock_handler(req: ModelRequest) -> AIMessage:
return AIMessage(content="mock response")
# Call awrap_model_call which modifies the request
await middleware.awrap_model_call(request, mock_handler)
# The request should have been modified in place
assert request.messages[0].content == ""
assert request.messages[1].content == "12345"
assert state["messages"] == request.messages
async def test_clear_tool_outputs_and_inputs_async() -> None:
"""Test async version of clearing tool outputs and inputs."""
tool_call_id = "call-2"
ai_message = AIMessage(
content=[
{"type": "tool_call", "id": tool_call_id, "name": "search", "args": {"query": "foo"}}
],
tool_calls=[{"id": tool_call_id, "name": "search", "args": {"query": "foo"}}],
)
tool_message = ToolMessage(content="x" * 200, tool_call_id=tool_call_id)
state, request = _make_state_and_request([ai_message, tool_message])
edit = ClearToolUsesEdit(
trigger=50,
clear_at_least=10,
clear_tool_inputs=True,
keep=0,
placeholder="[cleared output]",
)
middleware = ContextEditingMiddleware(edits=[edit])
async def mock_handler(req: ModelRequest) -> AIMessage:
return AIMessage(content="mock response")
# Call awrap_model_call which modifies the request
await middleware.awrap_model_call(request, mock_handler)
cleared_ai = request.messages[0]
cleared_tool = request.messages[1]
assert isinstance(cleared_tool, ToolMessage)
assert cleared_tool.content == "[cleared output]"
assert cleared_tool.response_metadata["context_editing"]["cleared"] is True
assert isinstance(cleared_ai, AIMessage)
assert cleared_ai.tool_calls[0]["args"] == {}
context_meta = cleared_ai.response_metadata.get("context_editing")
assert context_meta is not None
assert context_meta["cleared_tool_inputs"] == [tool_call_id]
assert state["messages"] == request.messages
async def test_respects_keep_last_tool_results_async() -> None:
"""Test async version respects keep parameter for last tool results."""
conversation: list[AIMessage | ToolMessage] = []
edits = [
("call-a", "tool-output-a" * 5),
("call-b", "tool-output-b" * 5),
("call-c", "tool-output-c" * 5),
]
for call_id, text in edits:
conversation.append(
AIMessage(
content="",
tool_calls=[{"id": call_id, "name": "tool", "args": {"input": call_id}}],
)
)
conversation.append(ToolMessage(content=text, tool_call_id=call_id))
state, request = _make_state_and_request(conversation)
middleware = ContextEditingMiddleware(
edits=[
ClearToolUsesEdit(
trigger=50,
keep=1,
placeholder="[cleared]",
)
],
token_count_method="model",
)
async def mock_handler(req: ModelRequest) -> AIMessage:
return AIMessage(content="mock response")
# Call awrap_model_call which modifies the request
await middleware.awrap_model_call(request, mock_handler)
cleared_messages = [
msg
for msg in request.messages
if isinstance(msg, ToolMessage) and msg.content == "[cleared]"
]
assert len(cleared_messages) == 2
assert isinstance(request.messages[-1], ToolMessage)
assert request.messages[-1].content != "[cleared]"
async def test_exclude_tools_prevents_clearing_async() -> None:
"""Test async version of excluding tools from clearing."""
search_call = "call-search"
calc_call = "call-calc"
state, request = _make_state_and_request(
[
AIMessage(
content="",
tool_calls=[{"id": search_call, "name": "search", "args": {"query": "foo"}}],
),
ToolMessage(content="search-results" * 20, tool_call_id=search_call),
AIMessage(
content="",
tool_calls=[{"id": calc_call, "name": "calculator", "args": {"a": 1, "b": 2}}],
),
ToolMessage(content="42", tool_call_id=calc_call),
]
)
middleware = ContextEditingMiddleware(
edits=[
ClearToolUsesEdit(
trigger=50,
clear_at_least=10,
keep=0,
exclude_tools=("search",),
placeholder="[cleared]",
)
],
)
async def mock_handler(req: ModelRequest) -> AIMessage:
return AIMessage(content="mock response")
# Call awrap_model_call which modifies the request
await middleware.awrap_model_call(request, mock_handler)
search_tool = request.messages[1]
calc_tool = request.messages[3]
assert isinstance(search_tool, ToolMessage)
assert search_tool.content == "search-results" * 20
assert isinstance(calc_tool, ToolMessage)
assert calc_tool.content == "[cleared]"

View File

@@ -44,7 +44,6 @@ from langchain.agents.middleware.model_call_limit import (
ModelCallLimitExceededError,
)
from langchain.agents.middleware.model_fallback import ModelFallbackMiddleware
from langchain.agents.middleware.prompt_caching import AnthropicPromptCachingMiddleware
from langchain.agents.middleware.summarization import SummarizationMiddleware
from langchain.agents.middleware.types import (
AgentMiddleware,
@@ -1024,115 +1023,6 @@ def test_human_in_the_loop_middleware_description_as_callable() -> None:
assert captured_request["action_requests"][1]["description"] == "Static description"
# Tests for AnthropicPromptCachingMiddleware
def test_anthropic_prompt_caching_middleware_initialization() -> None:
"""Test AnthropicPromptCachingMiddleware initialization."""
# Test with custom values
middleware = AnthropicPromptCachingMiddleware(
type="ephemeral", ttl="1h", min_messages_to_cache=5
)
assert middleware.type == "ephemeral"
assert middleware.ttl == "1h"
assert middleware.min_messages_to_cache == 5
# Test with default values
middleware = AnthropicPromptCachingMiddleware()
assert middleware.type == "ephemeral"
assert middleware.ttl == "5m"
assert middleware.min_messages_to_cache == 0
fake_request = ModelRequest(
model=FakeToolCallingModel(),
messages=[HumanMessage("Hello")],
system_prompt=None,
tool_choice=None,
tools=[],
response_format=None,
state={"messages": [HumanMessage("Hello")]},
runtime=cast(Runtime, object()),
model_settings={},
)
def mock_handler(req: ModelRequest) -> AIMessage:
return AIMessage(content="mock response", **req.model_settings)
result = middleware.wrap_model_call(fake_request, mock_handler)
# Check that model_settings were passed through via the request
assert fake_request.model_settings == {"cache_control": {"type": "ephemeral", "ttl": "5m"}}
def test_anthropic_prompt_caching_middleware_unsupported_model() -> None:
"""Test AnthropicPromptCachingMiddleware with unsupported model."""
from typing import cast
fake_request = ModelRequest(
model=FakeToolCallingModel(),
messages=[HumanMessage("Hello")],
system_prompt=None,
tool_choice=None,
tools=[],
response_format=None,
state={"messages": [HumanMessage("Hello")]},
runtime=cast(Runtime, object()),
model_settings={},
)
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="raise")
def mock_handler(req: ModelRequest) -> AIMessage:
return AIMessage(content="mock response")
with pytest.raises(
ValueError,
match="AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models. Please install langchain-anthropic.",
):
middleware.wrap_model_call(fake_request, mock_handler)
langchain_anthropic = ModuleType("langchain_anthropic")
class MockChatAnthropic:
pass
langchain_anthropic.ChatAnthropic = MockChatAnthropic
with patch.dict("sys.modules", {"langchain_anthropic": langchain_anthropic}):
with pytest.raises(
ValueError,
match="AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models, not instances of",
):
middleware.wrap_model_call(fake_request, mock_handler)
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="warn")
with warnings.catch_warnings(record=True) as w:
result = middleware.wrap_model_call(fake_request, mock_handler)
assert len(w) == 1
assert (
"AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models. Please install langchain-anthropic."
in str(w[-1].message)
)
assert isinstance(result, AIMessage)
with warnings.catch_warnings(record=True) as w:
with patch.dict("sys.modules", {"langchain_anthropic": langchain_anthropic}):
result = middleware.wrap_model_call(fake_request, mock_handler)
assert isinstance(result, AIMessage)
assert len(w) == 1
assert (
"AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models, not instances of"
in str(w[-1].message)
)
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore")
result = middleware.wrap_model_call(fake_request, mock_handler)
assert isinstance(result, AIMessage)
with patch.dict("sys.modules", {"langchain_anthropic": {"ChatAnthropic": object()}}):
result = middleware.wrap_model_call(fake_request, mock_handler)
assert isinstance(result, AIMessage)
# Tests for SummarizationMiddleware
def test_summarization_middleware_initialization() -> None:
"""Test SummarizationMiddleware initialization."""

View File

@@ -0,0 +1,220 @@
"""Unit tests for ModelFallbackMiddleware."""
from __future__ import annotations
from typing import cast
import pytest
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage
from langchain.agents.middleware.model_fallback import ModelFallbackMiddleware
from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langgraph.runtime import Runtime
def _fake_runtime() -> Runtime:
return cast(Runtime, object())
def _make_request() -> ModelRequest:
"""Create a minimal ModelRequest for testing."""
model = GenericFakeChatModel(messages=iter([AIMessage(content="primary")]))
return ModelRequest(
model=model,
system_prompt=None,
messages=[],
tool_choice=None,
tools=[],
response_format=None,
state=cast("AgentState", {}), # type: ignore[name-defined]
runtime=_fake_runtime(),
model_settings={},
)
def test_primary_model_succeeds() -> None:
"""Test that primary model is used when it succeeds."""
primary_model = GenericFakeChatModel(messages=iter([AIMessage(content="primary response")]))
fallback_model = GenericFakeChatModel(messages=iter([AIMessage(content="fallback response")]))
middleware = ModelFallbackMiddleware(fallback_model)
request = _make_request()
request.model = primary_model
def mock_handler(req: ModelRequest) -> ModelResponse:
# Simulate successful model call
result = req.model.invoke([])
return ModelResponse(result=[result])
response = middleware.wrap_model_call(request, mock_handler)
assert isinstance(response, ModelResponse)
assert response.result[0].content == "primary response"
def test_fallback_on_primary_failure() -> None:
"""Test that fallback model is used when primary fails."""
class FailingPrimaryModel(GenericFakeChatModel):
def _generate(self, messages, **kwargs):
raise ValueError("Primary model failed")
primary_model = FailingPrimaryModel(messages=iter([AIMessage(content="should not see")]))
fallback_model = GenericFakeChatModel(messages=iter([AIMessage(content="fallback response")]))
middleware = ModelFallbackMiddleware(fallback_model)
request = _make_request()
request.model = primary_model
def mock_handler(req: ModelRequest) -> ModelResponse:
result = req.model.invoke([])
return ModelResponse(result=[result])
response = middleware.wrap_model_call(request, mock_handler)
assert isinstance(response, ModelResponse)
assert response.result[0].content == "fallback response"
def test_multiple_fallbacks() -> None:
"""Test that multiple fallback models are tried in sequence."""
class FailingModel(GenericFakeChatModel):
def _generate(self, messages, **kwargs):
raise ValueError("Model failed")
primary_model = FailingModel(messages=iter([AIMessage(content="should not see")]))
fallback1 = FailingModel(messages=iter([AIMessage(content="fallback1")]))
fallback2 = GenericFakeChatModel(messages=iter([AIMessage(content="fallback2")]))
middleware = ModelFallbackMiddleware(fallback1, fallback2)
request = _make_request()
request.model = primary_model
def mock_handler(req: ModelRequest) -> ModelResponse:
result = req.model.invoke([])
return ModelResponse(result=[result])
response = middleware.wrap_model_call(request, mock_handler)
assert isinstance(response, ModelResponse)
assert response.result[0].content == "fallback2"
def test_all_models_fail() -> None:
"""Test that exception is raised when all models fail."""
class AlwaysFailingModel(GenericFakeChatModel):
def _generate(self, messages, **kwargs):
raise ValueError("Model failed")
primary_model = AlwaysFailingModel(messages=iter([]))
fallback_model = AlwaysFailingModel(messages=iter([]))
middleware = ModelFallbackMiddleware(fallback_model)
request = _make_request()
request.model = primary_model
def mock_handler(req: ModelRequest) -> ModelResponse:
result = req.model.invoke([])
return ModelResponse(result=[result])
with pytest.raises(ValueError, match="Model failed"):
middleware.wrap_model_call(request, mock_handler)
# ==============================================================================
# Async Tests
# ==============================================================================
async def test_primary_model_succeeds_async() -> None:
"""Test async version - primary model is used when it succeeds."""
primary_model = GenericFakeChatModel(messages=iter([AIMessage(content="primary response")]))
fallback_model = GenericFakeChatModel(messages=iter([AIMessage(content="fallback response")]))
middleware = ModelFallbackMiddleware(fallback_model)
request = _make_request()
request.model = primary_model
async def mock_handler(req: ModelRequest) -> ModelResponse:
# Simulate successful async model call
result = await req.model.ainvoke([])
return ModelResponse(result=[result])
response = await middleware.awrap_model_call(request, mock_handler)
assert isinstance(response, ModelResponse)
assert response.result[0].content == "primary response"
async def test_fallback_on_primary_failure_async() -> None:
"""Test async version - fallback model is used when primary fails."""
class AsyncFailingPrimaryModel(GenericFakeChatModel):
async def _agenerate(self, messages, **kwargs):
raise ValueError("Primary model failed")
primary_model = AsyncFailingPrimaryModel(messages=iter([AIMessage(content="should not see")]))
fallback_model = GenericFakeChatModel(messages=iter([AIMessage(content="fallback response")]))
middleware = ModelFallbackMiddleware(fallback_model)
request = _make_request()
request.model = primary_model
async def mock_handler(req: ModelRequest) -> ModelResponse:
result = await req.model.ainvoke([])
return ModelResponse(result=[result])
response = await middleware.awrap_model_call(request, mock_handler)
assert isinstance(response, ModelResponse)
assert response.result[0].content == "fallback response"
async def test_multiple_fallbacks_async() -> None:
"""Test async version - multiple fallback models are tried in sequence."""
class AsyncFailingModel(GenericFakeChatModel):
async def _agenerate(self, messages, **kwargs):
raise ValueError("Model failed")
primary_model = AsyncFailingModel(messages=iter([AIMessage(content="should not see")]))
fallback1 = AsyncFailingModel(messages=iter([AIMessage(content="fallback1")]))
fallback2 = GenericFakeChatModel(messages=iter([AIMessage(content="fallback2")]))
middleware = ModelFallbackMiddleware(fallback1, fallback2)
request = _make_request()
request.model = primary_model
async def mock_handler(req: ModelRequest) -> ModelResponse:
result = await req.model.ainvoke([])
return ModelResponse(result=[result])
response = await middleware.awrap_model_call(request, mock_handler)
assert isinstance(response, ModelResponse)
assert response.result[0].content == "fallback2"
async def test_all_models_fail_async() -> None:
"""Test async version - exception is raised when all models fail."""
class AsyncAlwaysFailingModel(GenericFakeChatModel):
async def _agenerate(self, messages, **kwargs):
raise ValueError("Model failed")
primary_model = AsyncAlwaysFailingModel(messages=iter([]))
fallback_model = AsyncAlwaysFailingModel(messages=iter([]))
middleware = ModelFallbackMiddleware(fallback_model)
request = _make_request()
request.model = primary_model
async def mock_handler(req: ModelRequest) -> ModelResponse:
result = await req.model.ainvoke([])
return ModelResponse(result=[result])
with pytest.raises(ValueError, match="Model failed"):
await middleware.awrap_model_call(request, mock_handler)

View File

@@ -0,0 +1,172 @@
"""Unit tests for PlanningMiddleware."""
from __future__ import annotations
from typing import cast
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage
from langchain.agents.middleware.planning import PlanningMiddleware
from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langgraph.runtime import Runtime
def _fake_runtime() -> Runtime:
return cast(Runtime, object())
def _make_request(system_prompt: str | None = None) -> ModelRequest:
"""Create a minimal ModelRequest for testing."""
model = GenericFakeChatModel(messages=iter([AIMessage(content="response")]))
return ModelRequest(
model=model,
system_prompt=system_prompt,
messages=[],
tool_choice=None,
tools=[],
response_format=None,
state=cast("AgentState", {}), # type: ignore[name-defined]
runtime=_fake_runtime(),
model_settings={},
)
def test_adds_system_prompt_when_none_exists() -> None:
"""Test that middleware adds system prompt when request has none."""
middleware = PlanningMiddleware()
request = _make_request(system_prompt=None)
def mock_handler(req: ModelRequest) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="response")])
middleware.wrap_model_call(request, mock_handler)
# System prompt should be set
assert request.system_prompt is not None
assert "write_todos" in request.system_prompt
def test_appends_to_existing_system_prompt() -> None:
"""Test that middleware appends to existing system prompt."""
existing_prompt = "You are a helpful assistant."
middleware = PlanningMiddleware()
request = _make_request(system_prompt=existing_prompt)
def mock_handler(req: ModelRequest) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="response")])
middleware.wrap_model_call(request, mock_handler)
# System prompt should contain both
assert request.system_prompt is not None
assert existing_prompt in request.system_prompt
assert "write_todos" in request.system_prompt
assert request.system_prompt.startswith(existing_prompt)
def test_custom_system_prompt() -> None:
"""Test that middleware uses custom system prompt."""
custom_prompt = "Custom planning instructions"
middleware = PlanningMiddleware(system_prompt=custom_prompt)
request = _make_request(system_prompt=None)
def mock_handler(req: ModelRequest) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="response")])
middleware.wrap_model_call(request, mock_handler)
# Should use custom prompt
assert request.system_prompt == custom_prompt
def test_has_write_todos_tool() -> None:
"""Test that middleware registers the write_todos tool."""
middleware = PlanningMiddleware()
# Should have one tool registered
assert len(middleware.tools) == 1
assert middleware.tools[0].name == "write_todos"
def test_custom_tool_description() -> None:
"""Test that middleware uses custom tool description."""
custom_description = "Custom todo tool description"
middleware = PlanningMiddleware(tool_description=custom_description)
# Tool should use custom description
assert len(middleware.tools) == 1
assert middleware.tools[0].description == custom_description
# ==============================================================================
# Async Tests
# ==============================================================================
async def test_adds_system_prompt_when_none_exists_async() -> None:
"""Test async version - middleware adds system prompt when request has none."""
middleware = PlanningMiddleware()
request = _make_request(system_prompt=None)
async def mock_handler(req: ModelRequest) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="response")])
await middleware.awrap_model_call(request, mock_handler)
# System prompt should be set
assert request.system_prompt is not None
assert "write_todos" in request.system_prompt
async def test_appends_to_existing_system_prompt_async() -> None:
"""Test async version - middleware appends to existing system prompt."""
existing_prompt = "You are a helpful assistant."
middleware = PlanningMiddleware()
request = _make_request(system_prompt=existing_prompt)
async def mock_handler(req: ModelRequest) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="response")])
await middleware.awrap_model_call(request, mock_handler)
# System prompt should contain both
assert request.system_prompt is not None
assert existing_prompt in request.system_prompt
assert "write_todos" in request.system_prompt
assert request.system_prompt.startswith(existing_prompt)
async def test_custom_system_prompt_async() -> None:
"""Test async version - middleware uses custom system prompt."""
custom_prompt = "Custom planning instructions"
middleware = PlanningMiddleware(system_prompt=custom_prompt)
request = _make_request(system_prompt=None)
async def mock_handler(req: ModelRequest) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="response")])
await middleware.awrap_model_call(request, mock_handler)
# Should use custom prompt
assert request.system_prompt == custom_prompt
async def test_handler_called_with_modified_request_async() -> None:
"""Test async version - handler receives the modified request."""
middleware = PlanningMiddleware()
request = _make_request(system_prompt="Original")
handler_called = {"value": False}
received_prompt = {"value": None}
async def mock_handler(req: ModelRequest) -> ModelResponse:
handler_called["value"] = True
received_prompt["value"] = req.system_prompt
return ModelResponse(result=[AIMessage(content="response")])
await middleware.awrap_model_call(request, mock_handler)
assert handler_called["value"]
assert received_prompt["value"] is not None
assert "Original" in received_prompt["value"]
assert "write_todos" in received_prompt["value"]

View File

@@ -0,0 +1,9 @@
"""Middleware for Anthropic models."""
from langchain_anthropic.middleware.prompt_caching import (
AnthropicPromptCachingMiddleware,
)
__all__ = [
"AnthropicPromptCachingMiddleware",
]
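
For downstream users, the net effect of this diff is an import-path change; a small before/after sketch (the "before" import corresponds to the lines removed from langchain's middleware __init__ at the top of this diff, and the constructor arguments are values exercised in the tests):

# Before this change:
# from langchain.agents.middleware import AnthropicPromptCachingMiddleware

# After this change:
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware

middleware = AnthropicPromptCachingMiddleware(ttl="1h", min_messages_to_cache=5)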

View File

@@ -0,0 +1,157 @@
"""Anthropic prompt caching middleware.
Requires:
- langchain: For the agent middleware framework
- langchain-anthropic: For the ChatAnthropic model (already satisfied,
  since this module ships in that package)
"""
from collections.abc import Awaitable, Callable
from typing import Literal
from warnings import warn
try:
from langchain.agents.middleware.types import (
AgentMiddleware,
ModelCallResult,
ModelRequest,
ModelResponse,
)
except ImportError as e:
msg = (
"AnthropicPromptCachingMiddleware requires 'langchain' to be installed. "
"This middleware is designed for use with LangChain agents. "
"Install it with: pip install langchain"
)
raise ImportError(msg) from e
class AnthropicPromptCachingMiddleware(AgentMiddleware):
"""Prompt Caching Middleware.
Optimizes API usage by caching conversation prefixes for Anthropic models.
Requires both 'langchain' and 'langchain-anthropic' packages to be installed.
Learn more about Anthropic prompt caching
[here](https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching).
"""
def __init__(
self,
type: Literal["ephemeral"] = "ephemeral", # noqa: A002
ttl: Literal["5m", "1h"] = "5m",
min_messages_to_cache: int = 0,
unsupported_model_behavior: Literal["ignore", "warn", "raise"] = "warn",
) -> None:
"""Initialize the middleware with cache control settings.
Args:
type: The type of cache to use, only "ephemeral" is supported.
ttl: The time to live for the cache, only "5m" and "1h" are
supported.
            min_messages_to_cache: The minimum number of messages required
                before caching is applied. Defaults to 0.
            unsupported_model_behavior: How to behave when an unsupported
                model is used. "ignore" continues without caching, "warn"
                emits a warning and continues, and "raise" raises an error
                and stops the agent.
"""
self.type = type
self.ttl = ttl
self.min_messages_to_cache = min_messages_to_cache
self.unsupported_model_behavior = unsupported_model_behavior
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
"""Modify the model request to add cache control blocks."""
try:
from langchain_anthropic import ChatAnthropic
chat_anthropic_cls: type | None = ChatAnthropic
except ImportError:
chat_anthropic_cls = None
msg: str | None = None
if chat_anthropic_cls is None:
msg = (
"AnthropicPromptCachingMiddleware caching middleware only supports "
"Anthropic models. "
"Please install langchain-anthropic."
)
elif not isinstance(request.model, chat_anthropic_cls):
msg = (
"AnthropicPromptCachingMiddleware caching middleware only supports "
f"Anthropic models, not instances of {type(request.model)}"
)
if msg is not None:
if self.unsupported_model_behavior == "raise":
raise ValueError(msg)
if self.unsupported_model_behavior == "warn":
warn(msg, stacklevel=3)
else:
return handler(request)
messages_count = (
len(request.messages) + 1
if request.system_prompt
else len(request.messages)
)
if messages_count < self.min_messages_to_cache:
return handler(request)
request.model_settings["cache_control"] = {"type": self.type, "ttl": self.ttl}
return handler(request)
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
"""Modify the model request to add cache control blocks (async version)."""
try:
from langchain_anthropic import ChatAnthropic
chat_anthropic_cls: type | None = ChatAnthropic
except ImportError:
chat_anthropic_cls = None
msg: str | None = None
if chat_anthropic_cls is None:
msg = (
"AnthropicPromptCachingMiddleware caching middleware only supports "
"Anthropic models. "
"Please install langchain-anthropic."
)
elif not isinstance(request.model, chat_anthropic_cls):
msg = (
"AnthropicPromptCachingMiddleware caching middleware only supports "
f"Anthropic models, not instances of {type(request.model)}"
)
if msg is not None:
if self.unsupported_model_behavior == "raise":
raise ValueError(msg)
if self.unsupported_model_behavior == "warn":
warn(msg, stacklevel=3)
else:
return await handler(request)
messages_count = (
len(request.messages) + 1
if request.system_prompt
else len(request.messages)
)
if messages_count < self.min_messages_to_cache:
return await handler(request)
request.model_settings["cache_control"] = {"type": self.type, "ttl": self.ttl}
return await handler(request)
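
A brief sketch of the supported path, assuming a real ChatAnthropic instance: the middleware attaches cache_control to request.model_settings before delegating. The model name and API key below are placeholders, not part of this diff (no API call is made; the handler is a stub), and the ModelRequest fields mirror the new tests.

from typing import cast

from langchain_anthropic import ChatAnthropic
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
from langchain_core.messages import AIMessage, HumanMessage
from langgraph.runtime import Runtime

from langchain.agents.middleware.types import ModelRequest, ModelResponse

request = ModelRequest(
    model=ChatAnthropic(model="claude-sonnet-4-5", api_key="placeholder-key"),
    messages=[HumanMessage("Hello")],
    system_prompt=None,
    tool_choice=None,
    tools=[],
    response_format=None,
    state={"messages": [HumanMessage("Hello")]},
    runtime=cast(Runtime, object()),
    model_settings={},
)


def handler(req: ModelRequest) -> ModelResponse:
    # Stub handler; the model itself is never invoked in this sketch.
    return ModelResponse(result=[AIMessage(content="ok")])


AnthropicPromptCachingMiddleware(ttl="1h").wrap_model_call(request, handler)
assert request.model_settings == {"cache_control": {"type": "ephemeral", "ttl": "1h"}}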

View File

@@ -41,6 +41,7 @@ test = [
"vcrpy>=7.0.0,<8.0.0",
"langchain-core",
"langchain-tests",
"langchain",
]
lint = ["ruff>=0.13.1,<0.14.0"]
dev = ["langchain-core"]
@@ -55,6 +56,7 @@ typing = [
[tool.uv.sources]
langchain-core = { path = "../../core", editable = true }
langchain-tests = { path = "../../standard-tests", editable = true }
langchain = { path = "../../langchain_v1", editable = true }
[tool.mypy]
disallow_untyped_defs = "True"

View File

@@ -0,0 +1 @@
"""Tests for Anthropic middleware."""

View File

@@ -0,0 +1,238 @@
"""Tests for Anthropic prompt caching middleware."""
import warnings
from typing import Any, cast
import pytest
from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langgraph.runtime import Runtime
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
class FakeToolCallingModel(BaseChatModel):
"""Fake model for testing middleware."""
def _generate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
    ) -> ChatResult:
        """Top-level call."""
messages_string = "-".join([str(m.content) for m in messages])
message = AIMessage(content=messages_string, id="0")
return ChatResult(generations=[ChatGeneration(message=message)])
async def _agenerate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: AsyncCallbackManagerForLLMRun | None = None,
**kwargs: Any,
    ) -> ChatResult:
        """Async top-level call."""
messages_string = "-".join([str(m.content) for m in messages])
message = AIMessage(content=messages_string, id="0")
return ChatResult(generations=[ChatGeneration(message=message)])
@property
def _llm_type(self) -> str:
return "fake-tool-call-model"
def test_anthropic_prompt_caching_middleware_initialization() -> None:
"""Test AnthropicPromptCachingMiddleware initialization."""
# Test with custom values
middleware = AnthropicPromptCachingMiddleware(
type="ephemeral", ttl="1h", min_messages_to_cache=5
)
assert middleware.type == "ephemeral"
assert middleware.ttl == "1h"
assert middleware.min_messages_to_cache == 5
# Test with default values
middleware = AnthropicPromptCachingMiddleware()
assert middleware.type == "ephemeral"
assert middleware.ttl == "5m"
assert middleware.min_messages_to_cache == 0
fake_request = ModelRequest(
model=FakeToolCallingModel(),
messages=[HumanMessage("Hello")],
system_prompt=None,
tool_choice=None,
tools=[],
response_format=None,
state={"messages": [HumanMessage("Hello")]},
runtime=cast(Runtime, object()),
model_settings={},
)
def mock_handler(req: ModelRequest) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="mock response")])
middleware.wrap_model_call(fake_request, mock_handler)
# Check that model_settings were passed through via the request
assert fake_request.model_settings == {
"cache_control": {"type": "ephemeral", "ttl": "5m"}
}
def test_anthropic_prompt_caching_middleware_unsupported_model() -> None:
"""Test AnthropicPromptCachingMiddleware with unsupported model."""
fake_request = ModelRequest(
model=FakeToolCallingModel(),
messages=[HumanMessage("Hello")],
system_prompt=None,
tool_choice=None,
tools=[],
response_format=None,
state={"messages": [HumanMessage("Hello")]},
runtime=cast(Runtime, object()),
model_settings={},
)
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="raise")
def mock_handler(req: ModelRequest) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="mock response")])
# Since we're in the langchain-anthropic package, ChatAnthropic is always
# available. Test that it raises an error for unsupported model instances
with pytest.raises(
ValueError,
match=(
"AnthropicPromptCachingMiddleware caching middleware only supports "
"Anthropic models, not instances of"
),
):
middleware.wrap_model_call(fake_request, mock_handler)
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="warn")
# Test warn behavior for unsupported model instances
with warnings.catch_warnings(record=True) as w:
result = middleware.wrap_model_call(fake_request, mock_handler)
assert isinstance(result, ModelResponse)
assert len(w) == 1
assert (
"AnthropicPromptCachingMiddleware caching middleware only supports "
"Anthropic models, not instances of"
) in str(w[-1].message)
# Test ignore behavior
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore")
result = middleware.wrap_model_call(fake_request, mock_handler)
assert isinstance(result, ModelResponse)
async def test_anthropic_prompt_caching_middleware_async() -> None:
"""Test AnthropicPromptCachingMiddleware async path."""
# Test with custom values
middleware = AnthropicPromptCachingMiddleware(
type="ephemeral", ttl="1h", min_messages_to_cache=5
)
fake_request = ModelRequest(
model=FakeToolCallingModel(),
messages=[HumanMessage("Hello")] * 6,
system_prompt=None,
tool_choice=None,
tools=[],
response_format=None,
state={"messages": [HumanMessage("Hello")] * 6},
runtime=cast(Runtime, object()),
model_settings={},
)
async def mock_handler(req: ModelRequest) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="mock response")])
result = await middleware.awrap_model_call(fake_request, mock_handler)
assert isinstance(result, ModelResponse)
# Check that model_settings were passed through via the request
assert fake_request.model_settings == {
"cache_control": {"type": "ephemeral", "ttl": "1h"}
}
async def test_anthropic_prompt_caching_middleware_async_unsupported_model() -> None:
"""Test AnthropicPromptCachingMiddleware async path with unsupported model."""
fake_request = ModelRequest(
model=FakeToolCallingModel(),
messages=[HumanMessage("Hello")],
system_prompt=None,
tool_choice=None,
tools=[],
response_format=None,
state={"messages": [HumanMessage("Hello")]},
runtime=cast(Runtime, object()),
model_settings={},
)
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="raise")
async def mock_handler(req: ModelRequest) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="mock response")])
# Test that it raises an error for unsupported model instances
with pytest.raises(
ValueError,
match=(
"AnthropicPromptCachingMiddleware caching middleware only supports "
"Anthropic models, not instances of"
),
):
await middleware.awrap_model_call(fake_request, mock_handler)
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="warn")
# Test warn behavior for unsupported model instances
with warnings.catch_warnings(record=True) as w:
result = await middleware.awrap_model_call(fake_request, mock_handler)
assert isinstance(result, ModelResponse)
assert len(w) == 1
assert (
"AnthropicPromptCachingMiddleware caching middleware only supports "
"Anthropic models, not instances of"
) in str(w[-1].message)
# Test ignore behavior
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore")
result = await middleware.awrap_model_call(fake_request, mock_handler)
assert isinstance(result, ModelResponse)
async def test_anthropic_prompt_caching_middleware_async_min_messages() -> None:
"""Test async path respects min_messages_to_cache."""
middleware = AnthropicPromptCachingMiddleware(min_messages_to_cache=5)
# Test with fewer messages than minimum
fake_request = ModelRequest(
model=FakeToolCallingModel(),
messages=[HumanMessage("Hello")] * 3,
system_prompt=None,
tool_choice=None,
tools=[],
response_format=None,
state={"messages": [HumanMessage("Hello")] * 3},
runtime=cast(Runtime, object()),
model_settings={},
)
async def mock_handler(req: ModelRequest) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="mock response")])
result = await middleware.awrap_model_call(fake_request, mock_handler)
assert isinstance(result, ModelResponse)
# Cache control should NOT be added when message count is below minimum
assert fake_request.model_settings == {}

File diff suppressed because it is too large