Compare commits

...

6 Commits

Author          SHA1        Message                               Date
Eugene Yurtsev  c04676f93e  x                                     2025-10-13 15:18:57 -04:00
Eugene Yurtsev  9358d12917  x                                     2025-10-13 15:16:45 -04:00
Eugene Yurtsev  23f5b0cedf  x                                     2025-10-13 14:34:36 -04:00
Eugene Yurtsev  89e3a10cbd  langchain-anthropic middleware        2025-10-13 10:05:14 -04:00
Eugene Yurtsev  fad0e2dbd8  update anthropic pyproject / uv lock  2025-10-13 10:02:46 -04:00
Eugene Yurtsev  6c2f0eb67a  update langchain_v1                   2025-10-13 10:02:19 -04:00
11 changed files with 1542 additions and 605 deletions

View File

@@ -12,7 +12,6 @@ from .model_call_limit import ModelCallLimitMiddleware
from .model_fallback import ModelFallbackMiddleware
from .pii import PIIDetectionError, PIIMiddleware
from .planning import PlanningMiddleware
from .prompt_caching import AnthropicPromptCachingMiddleware
from .summarization import SummarizationMiddleware
from .tool_call_limit import ToolCallLimitMiddleware
from .tool_emulator import LLMToolEmulator
@@ -33,8 +32,6 @@ from .types import (
__all__ = [
"AgentMiddleware",
"AgentState",
# should move to langchain-anthropic if we decide to keep it
"AnthropicPromptCachingMiddleware",
"ClearToolUsesEdit",
"ContextEditingMiddleware",
"HumanInTheLoopMiddleware",

View File

@@ -3,7 +3,7 @@
from __future__ import annotations
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from dataclasses import dataclass, field, replace
from inspect import iscoroutinefunction
from typing import (
TYPE_CHECKING,
@@ -30,7 +30,7 @@ from langgraph.channels.untracked_value import UntrackedValue
from langgraph.graph.message import add_messages
from langgraph.types import Command # noqa: TC002
from langgraph.typing import ContextT
from typing_extensions import NotRequired, Required, TypedDict, TypeVar
from typing_extensions import NotRequired, Required, TypedDict, TypeVar, Unpack
if TYPE_CHECKING:
from langchain_core.language_models.chat_models import BaseChatModel
@@ -62,6 +62,18 @@ JumpTo = Literal["tools", "model", "end"]
ResponseT = TypeVar("ResponseT")
class _ModelRequestOverrides(TypedDict, total=False):
"""Possible overrides for ModelRequest.override() method."""
model: BaseChatModel
system_prompt: str | None
messages: list[AnyMessage]
tool_choice: Any | None
tools: list[BaseTool | dict]
response_format: ResponseFormat | None
model_settings: dict[str, Any]
@dataclass
class ModelRequest:
"""Model request information for the agent."""
@@ -76,6 +88,36 @@ class ModelRequest:
runtime: Runtime[ContextT] # type: ignore[valid-type]
model_settings: dict[str, Any] = field(default_factory=dict)
def override(self, **overrides: Unpack[_ModelRequestOverrides]) -> ModelRequest:
"""Replace the request with a new request with the given overrides.
Returns a new `ModelRequest` instance with the specified attributes replaced.
This follows an immutable pattern, leaving the original request unchanged.
Args:
**overrides: Keyword arguments for attributes to override. Supported keys:
- model: BaseChatModel instance
- system_prompt: Optional system prompt string
- messages: List of messages
- tool_choice: Tool choice configuration
- tools: List of available tools
- response_format: Response format specification
- model_settings: Additional model settings
Returns:
New ModelRequest instance with specified overrides applied.
Examples:
```python
# Create a new request with different model
new_request = request.override(model=different_model)
# Override multiple attributes
new_request = request.override(system_prompt="New instructions", tool_choice="auto")
```
"""
return replace(self, **overrides)
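For context, a hedged sketch of how middleware code might consume this new method; the subclass name and prompt text are illustrative, not part of this diff:

```python
# Minimal sketch (assumed names): a middleware that swaps the system prompt
# through the new immutable override() instead of mutating the request.
from langchain.agents.middleware.types import AgentMiddleware, ModelRequest


class RewritePromptMiddleware(AgentMiddleware):
    """Hypothetical middleware; only the override() call reflects this diff."""

    def wrap_model_call(self, request: ModelRequest, handler):
        # override() returns a fresh ModelRequest; `request` stays unchanged.
        return handler(request.override(system_prompt="Answer concisely."))
```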
@dataclass
class ModelResponse:

View File

@@ -81,6 +81,7 @@ from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langgraph.runtime import get_runtime
from langgraph.types import Command, Send
from pydantic import BaseModel, ValidationError
from typing_extensions import Unpack
if TYPE_CHECKING:
from collections.abc import Sequence
@@ -104,6 +105,12 @@ TOOL_INVOCATION_ERROR_TEMPLATE = (
)
class _ToolCallRequestOverrides(TypedDict, total=False):
"""Possible overrides for ToolCallRequest.override() method."""
tool_call: ToolCall
@dataclass()
class ToolCallRequest:
"""Tool execution request passed to tool call interceptors.
@@ -120,6 +127,31 @@ class ToolCallRequest:
state: Any
runtime: Any
def override(self, **overrides: Unpack[_ToolCallRequestOverrides]) -> ToolCallRequest:
"""Replace the request with a new request with the given overrides.
Returns a new `ToolCallRequest` instance with the specified attributes replaced.
This follows an immutable pattern, leaving the original request unchanged.
Args:
**overrides: Keyword arguments for attributes to override. Supported keys:
- tool_call: Tool call dict with name, args, and id
Returns:
New ToolCallRequest instance with specified overrides applied.
Examples:
```python
# Modify tool call arguments without mutating original
modified_call = {**request.tool_call, "args": {"value": 10}}
new_request = request.override(tool_call=modified_call)
# Override multiple attributes
new_request = request.override(tool_call=modified_call, state=new_state)
```
"""
return replace(self, **overrides)
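Similarly, a hedged sketch of a tool-call interceptor built on this method; the function name and argument values are illustrative:

```python
# Minimal sketch (assumed names): an interceptor matching the ToolCallWrapper
# signature below, rewriting one argument without mutating the original call.
def clamp_args(request, handler):
    modified_call = {**request.tool_call, "args": {"value": 10}}
    return handler(request.override(tool_call=modified_call))
```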
ToolCallWrapper = Callable[
[ToolCallRequest, Callable[[ToolCallRequest], ToolMessage | Command]],

View File

@@ -0,0 +1,381 @@
"""Unit tests for override() methods on ModelRequest and ToolCallRequest."""
import pytest
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.tools import BaseTool
from langchain.agents.middleware.types import ModelRequest
from langchain.tools.tool_node import ToolCallRequest
class TestModelRequestOverride:
"""Test the ModelRequest.override() method."""
def test_override_single_attribute(self) -> None:
"""Test overriding a single attribute."""
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
original_request = ModelRequest(
model=model,
system_prompt="Original prompt",
messages=[HumanMessage("Hi")],
tool_choice=None,
tools=[],
response_format=None,
state={},
runtime=None,
)
new_request = original_request.override(system_prompt="New prompt")
# New request should have the overridden value
assert new_request.system_prompt == "New prompt"
# Original request should be unchanged (immutability)
assert original_request.system_prompt == "Original prompt"
# Other attributes should be the same
assert new_request.model == original_request.model
assert new_request.messages == original_request.messages
def test_override_multiple_attributes(self) -> None:
"""Test overriding multiple attributes at once."""
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
original_request = ModelRequest(
model=model,
system_prompt="Original prompt",
messages=[HumanMessage("Hi")],
tool_choice=None,
tools=[],
response_format=None,
state={"count": 1},
runtime=None,
)
new_request = original_request.override(
system_prompt="New prompt",
tool_choice="auto",
state={"count": 2},
)
# Overridden values should be changed
assert new_request.system_prompt == "New prompt"
assert new_request.tool_choice == "auto"
assert new_request.state == {"count": 2}
# Original should be unchanged
assert original_request.system_prompt == "Original prompt"
assert original_request.tool_choice is None
assert original_request.state == {"count": 1}
def test_override_messages(self) -> None:
"""Test overriding messages list."""
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
original_messages = [HumanMessage("Hi")]
new_messages = [HumanMessage("Hello"), AIMessage("Hi there")]
original_request = ModelRequest(
model=model,
system_prompt=None,
messages=original_messages,
tool_choice=None,
tools=[],
response_format=None,
state={},
runtime=None,
)
new_request = original_request.override(messages=new_messages)
assert new_request.messages == new_messages
assert original_request.messages == original_messages
assert len(new_request.messages) == 2
assert len(original_request.messages) == 1
def test_override_model_settings(self) -> None:
"""Test overriding model_settings dict."""
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
original_request = ModelRequest(
model=model,
system_prompt=None,
messages=[HumanMessage("Hi")],
tool_choice=None,
tools=[],
response_format=None,
state={},
runtime=None,
model_settings={"temperature": 0.5},
)
new_request = original_request.override(
model_settings={"temperature": 0.9, "max_tokens": 100}
)
assert new_request.model_settings == {"temperature": 0.9, "max_tokens": 100}
assert original_request.model_settings == {"temperature": 0.5}
def test_override_with_none_value(self) -> None:
"""Test overriding with None value."""
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
original_request = ModelRequest(
model=model,
system_prompt="Original prompt",
messages=[HumanMessage("Hi")],
tool_choice="auto",
tools=[],
response_format=None,
state={},
runtime=None,
)
new_request = original_request.override(
system_prompt=None,
tool_choice=None,
)
assert new_request.system_prompt is None
assert new_request.tool_choice is None
assert original_request.system_prompt == "Original prompt"
assert original_request.tool_choice == "auto"
def test_override_preserves_identity_of_unchanged_objects(self) -> None:
"""Test that unchanged attributes maintain object identity."""
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
messages = [HumanMessage("Hi")]
state = {"key": "value"}
original_request = ModelRequest(
model=model,
system_prompt="Original prompt",
messages=messages,
tool_choice=None,
tools=[],
response_format=None,
state=state,
runtime=None,
)
new_request = original_request.override(system_prompt="New prompt")
# Unchanged objects should be the same instance
assert new_request.messages is messages
assert new_request.state is state
assert new_request.model is model
def test_override_chaining(self) -> None:
"""Test chaining multiple override calls."""
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
original_request = ModelRequest(
model=model,
system_prompt="Prompt 1",
messages=[HumanMessage("Hi")],
tool_choice=None,
tools=[],
response_format=None,
state={"count": 1},
runtime=None,
)
final_request = (
original_request.override(system_prompt="Prompt 2")
.override(state={"count": 2})
.override(tool_choice="auto")
)
assert final_request.system_prompt == "Prompt 2"
assert final_request.state == {"count": 2}
assert final_request.tool_choice == "auto"
# Original should be unchanged
assert original_request.system_prompt == "Prompt 1"
assert original_request.state == {"count": 1}
assert original_request.tool_choice is None
class TestToolCallRequestOverride:
"""Test the ToolCallRequest.override() method."""
def test_override_tool_call(self) -> None:
"""Test overriding tool_call dict."""
from langchain_core.tools import tool
@tool
def test_tool(x: int) -> str:
"""A test tool."""
return f"Result: {x}"
original_call = {"name": "test_tool", "args": {"x": 5}, "id": "1", "type": "tool_call"}
modified_call = {"name": "test_tool", "args": {"x": 10}, "id": "1", "type": "tool_call"}
original_request = ToolCallRequest(
tool_call=original_call,
tool=test_tool,
state={"messages": []},
runtime=None,
)
new_request = original_request.override(tool_call=modified_call)
# New request should have modified tool_call
assert new_request.tool_call["args"]["x"] == 10
# Original should be unchanged
assert original_request.tool_call["args"]["x"] == 5
# Other attributes should be the same
assert new_request.tool is original_request.tool
assert new_request.state is original_request.state
def test_override_state(self) -> None:
"""Test overriding state."""
from langchain_core.tools import tool
@tool
def test_tool(x: int) -> str:
"""A test tool."""
return f"Result: {x}"
tool_call = {"name": "test_tool", "args": {"x": 5}, "id": "1", "type": "tool_call"}
original_state = {"messages": [HumanMessage("Hi")]}
new_state = {"messages": [HumanMessage("Hi"), AIMessage("Hello")]}
original_request = ToolCallRequest(
tool_call=tool_call,
tool=test_tool,
state=original_state,
runtime=None,
)
new_request = original_request.override(state=new_state)
assert len(new_request.state["messages"]) == 2
assert len(original_request.state["messages"]) == 1
def test_override_multiple_attributes(self) -> None:
"""Test overriding multiple attributes at once."""
from langchain_core.tools import tool
@tool
def test_tool(x: int) -> str:
"""A test tool."""
return f"Result: {x}"
@tool
def another_tool(y: str) -> str:
"""Another test tool."""
return f"Output: {y}"
original_call = {"name": "test_tool", "args": {"x": 5}, "id": "1", "type": "tool_call"}
modified_call = {
"name": "another_tool",
"args": {"y": "hello"},
"id": "2",
"type": "tool_call",
}
original_request = ToolCallRequest(
tool_call=original_call,
tool=test_tool,
state={"count": 1},
runtime=None,
)
new_request = original_request.override(
tool_call=modified_call,
tool=another_tool,
state={"count": 2},
)
assert new_request.tool_call["name"] == "another_tool"
assert new_request.tool.name == "another_tool"
assert new_request.state == {"count": 2}
# Original unchanged
assert original_request.tool_call["name"] == "test_tool"
assert original_request.tool.name == "test_tool"
assert original_request.state == {"count": 1}
def test_override_with_copy_pattern(self) -> None:
"""Test common pattern of copying and modifying tool_call."""
from langchain_core.tools import tool
@tool
def test_tool(value: int) -> str:
"""A test tool."""
return f"Result: {value}"
original_call = {
"name": "test_tool",
"args": {"value": 5},
"id": "call_123",
"type": "tool_call",
}
original_request = ToolCallRequest(
tool_call=original_call,
tool=test_tool,
state={},
runtime=None,
)
# Common pattern: copy tool_call and modify args
modified_call = {**original_request.tool_call, "args": {"value": 10}}
new_request = original_request.override(tool_call=modified_call)
assert new_request.tool_call["args"]["value"] == 10
assert new_request.tool_call["id"] == "call_123"
assert new_request.tool_call["name"] == "test_tool"
# Original unchanged
assert original_request.tool_call["args"]["value"] == 5
def test_override_preserves_identity(self) -> None:
"""Test that unchanged attributes maintain object identity."""
from langchain_core.tools import tool
@tool
def test_tool(x: int) -> str:
"""A test tool."""
return f"Result: {x}"
tool_call = {"name": "test_tool", "args": {"x": 5}, "id": "1", "type": "tool_call"}
state = {"messages": []}
original_request = ToolCallRequest(
tool_call=tool_call,
tool=test_tool,
state=state,
runtime=None,
)
new_call = {"name": "test_tool", "args": {"x": 10}, "id": "1", "type": "tool_call"}
new_request = original_request.override(tool_call=new_call)
# Unchanged objects should be the same instance
assert new_request.tool is test_tool
assert new_request.state is state
def test_override_chaining(self) -> None:
"""Test chaining multiple override calls."""
from langchain_core.tools import tool
@tool
def test_tool(x: int) -> str:
"""A test tool."""
return f"Result: {x}"
tool_call = {"name": "test_tool", "args": {"x": 5}, "id": "1", "type": "tool_call"}
original_request = ToolCallRequest(
tool_call=tool_call,
tool=test_tool,
state={"count": 1},
runtime=None,
)
call_2 = {"name": "test_tool", "args": {"x": 10}, "id": "1", "type": "tool_call"}
call_3 = {"name": "test_tool", "args": {"x": 15}, "id": "1", "type": "tool_call"}
final_request = (
original_request.override(tool_call=call_2)
.override(state={"count": 2})
.override(tool_call=call_3)
)
assert final_request.tool_call["args"]["x"] == 15
assert final_request.state == {"count": 2}
# Original unchanged
assert original_request.tool_call["args"]["x"] == 5
assert original_request.state == {"count": 1}

View File

@@ -44,7 +44,6 @@ from langchain.agents.middleware.model_call_limit import (
ModelCallLimitExceededError,
)
from langchain.agents.middleware.model_fallback import ModelFallbackMiddleware
from langchain.agents.middleware.prompt_caching import AnthropicPromptCachingMiddleware
from langchain.agents.middleware.summarization import SummarizationMiddleware
from langchain.agents.middleware.types import (
AgentMiddleware,
@@ -1024,115 +1023,6 @@ def test_human_in_the_loop_middleware_description_as_callable() -> None:
assert captured_request["action_requests"][1]["description"] == "Static description"
# Tests for AnthropicPromptCachingMiddleware
def test_anthropic_prompt_caching_middleware_initialization() -> None:
"""Test AnthropicPromptCachingMiddleware initialization."""
# Test with custom values
middleware = AnthropicPromptCachingMiddleware(
type="ephemeral", ttl="1h", min_messages_to_cache=5
)
assert middleware.type == "ephemeral"
assert middleware.ttl == "1h"
assert middleware.min_messages_to_cache == 5
# Test with default values
middleware = AnthropicPromptCachingMiddleware()
assert middleware.type == "ephemeral"
assert middleware.ttl == "5m"
assert middleware.min_messages_to_cache == 0
fake_request = ModelRequest(
model=FakeToolCallingModel(),
messages=[HumanMessage("Hello")],
system_prompt=None,
tool_choice=None,
tools=[],
response_format=None,
state={"messages": [HumanMessage("Hello")]},
runtime=cast(Runtime, object()),
model_settings={},
)
def mock_handler(req: ModelRequest) -> AIMessage:
return AIMessage(content="mock response", **req.model_settings)
result = middleware.wrap_model_call(fake_request, mock_handler)
# Check that model_settings were passed through via the request
assert fake_request.model_settings == {"cache_control": {"type": "ephemeral", "ttl": "5m"}}
def test_anthropic_prompt_caching_middleware_unsupported_model() -> None:
"""Test AnthropicPromptCachingMiddleware with unsupported model."""
from typing import cast
fake_request = ModelRequest(
model=FakeToolCallingModel(),
messages=[HumanMessage("Hello")],
system_prompt=None,
tool_choice=None,
tools=[],
response_format=None,
state={"messages": [HumanMessage("Hello")]},
runtime=cast(Runtime, object()),
model_settings={},
)
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="raise")
def mock_handler(req: ModelRequest) -> AIMessage:
return AIMessage(content="mock response")
with pytest.raises(
ValueError,
match="AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models. Please install langchain-anthropic.",
):
middleware.wrap_model_call(fake_request, mock_handler)
langchain_anthropic = ModuleType("langchain_anthropic")
class MockChatAnthropic:
pass
langchain_anthropic.ChatAnthropic = MockChatAnthropic
with patch.dict("sys.modules", {"langchain_anthropic": langchain_anthropic}):
with pytest.raises(
ValueError,
match="AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models, not instances of",
):
middleware.wrap_model_call(fake_request, mock_handler)
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="warn")
with warnings.catch_warnings(record=True) as w:
result = middleware.wrap_model_call(fake_request, mock_handler)
assert len(w) == 1
assert (
"AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models. Please install langchain-anthropic."
in str(w[-1].message)
)
assert isinstance(result, AIMessage)
with warnings.catch_warnings(record=True) as w:
with patch.dict("sys.modules", {"langchain_anthropic": langchain_anthropic}):
result = middleware.wrap_model_call(fake_request, mock_handler)
assert isinstance(result, AIMessage)
assert len(w) == 1
assert (
"AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models, not instances of"
in str(w[-1].message)
)
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore")
result = middleware.wrap_model_call(fake_request, mock_handler)
assert isinstance(result, AIMessage)
with patch.dict("sys.modules", {"langchain_anthropic": {"ChatAnthropic": object()}}):
result = middleware.wrap_model_call(fake_request, mock_handler)
assert isinstance(result, AIMessage)
# Tests for SummarizationMiddleware
def test_summarization_middleware_initialization() -> None:
"""Test SummarizationMiddleware initialization."""

View File

@@ -0,0 +1,9 @@
"""Middleware for Anthropic models."""
from langchain_anthropic.middleware.prompt_caching import (
AnthropicPromptCachingMiddleware,
)
__all__ = [
"AnthropicPromptCachingMiddleware",
]
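With the middleware relocated, downstream code imports it from `langchain_anthropic.middleware` rather than `langchain.agents.middleware`. A hedged usage sketch; the model name is a placeholder and the `create_agent` wiring assumes the langchain v1 agent API:

```python
# Assumed wiring for langchain v1; the model name is a placeholder.
from langchain.agents import create_agent
from langchain_anthropic import ChatAnthropic
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware

agent = create_agent(
    model=ChatAnthropic(model="claude-sonnet-4-5"),
    middleware=[AnthropicPromptCachingMiddleware(ttl="5m")],
)
```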

View File

@@ -1,15 +1,28 @@
"""Anthropic prompt caching middleware."""
"""Anthropic prompt caching middleware.
Requires:
- langchain: For agent middleware framework
- langchain-anthropic: For ChatAnthropic model (already a dependency)
"""
from collections.abc import Callable
from typing import Literal
from warnings import warn
from langchain.agents.middleware.types import (
AgentMiddleware,
ModelCallResult,
ModelRequest,
ModelResponse,
)
try:
from langchain.agents.middleware.types import (
AgentMiddleware,
ModelCallResult,
ModelRequest,
ModelResponse,
)
except ImportError as e:
msg = (
"AnthropicPromptCachingMiddleware requires 'langchain' to be installed. "
"This middleware is designed for use with LangChain agents. "
"Install it with: pip install langchain"
)
raise ImportError(msg) from e
class AnthropicPromptCachingMiddleware(AgentMiddleware):
@@ -17,13 +30,15 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
Optimizes API usage by caching conversation prefixes for Anthropic models.
Requires both 'langchain' and 'langchain-anthropic' packages to be installed.
Learn more about Anthropic prompt caching
[here](https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching).
"""
def __init__(
self,
type: Literal["ephemeral"] = "ephemeral",
type: Literal["ephemeral"] = "ephemeral", # noqa: A002
ttl: Literal["5m", "1h"] = "5m",
min_messages_to_cache: int = 0,
unsupported_model_behavior: Literal["ignore", "warn", "raise"] = "warn",
@@ -32,13 +47,15 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
Args:
type: The type of cache to use, only "ephemeral" is supported.
ttl: The time to live for the cache, only "5m" and "1h" are supported.
min_messages_to_cache: The minimum number of messages until the cache is used,
default is 0.
unsupported_model_behavior: The behavior to take when an unsupported model is used.
"ignore" will ignore the unsupported model and continue without caching.
"warn" will warn the user and continue without caching.
"raise" will raise an error and stop the agent.
ttl: The time to live for the cache, only "5m" and "1h" are
supported.
min_messages_to_cache: The minimum number of messages until the
cache is used, default is 0.
unsupported_model_behavior: The behavior to take when an
unsupported model is used. "ignore" will ignore the unsupported
model and continue without caching. "warn" will warn the user
and continue without caching. "raise" will raise an error and
stop the agent.
"""
self.type = type
self.ttl = ttl
@@ -53,18 +70,20 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
"""Modify the model request to add cache control blocks."""
try:
from langchain_anthropic import ChatAnthropic
chat_anthropic_cls: type | None = ChatAnthropic
except ImportError:
ChatAnthropic = None # noqa: N806
chat_anthropic_cls = None
msg: str | None = None
if ChatAnthropic is None:
if chat_anthropic_cls is None:
msg = (
"AnthropicPromptCachingMiddleware caching middleware only supports "
"Anthropic models. "
"Please install langchain-anthropic."
)
elif not isinstance(request.model, ChatAnthropic):
elif not isinstance(request.model, chat_anthropic_cls):
msg = (
"AnthropicPromptCachingMiddleware caching middleware only supports "
f"Anthropic models, not instances of {type(request.model)}"
@@ -79,7 +98,9 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
return handler(request)
messages_count = (
len(request.messages) + 1 if request.system_prompt else len(request.messages)
len(request.messages) + 1
if request.system_prompt
else len(request.messages)
)
if messages_count < self.min_messages_to_cache:
return handler(request)
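To make the threshold above concrete, a small worked example of the count logic: the system prompt counts as one extra message when checking `min_messages_to_cache`:

```python
def messages_count(n_messages: int, has_system_prompt: bool) -> int:
    # Mirrors the hunk above: a system prompt adds one to the count.
    return n_messages + 1 if has_system_prompt else n_messages


# With min_messages_to_cache=5: four messages plus a system prompt qualify
# for caching, while four messages without one fall below the threshold.
assert messages_count(4, has_system_prompt=True) == 5
assert messages_count(4, has_system_prompt=False) == 4
```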

View File

@@ -41,6 +41,7 @@ test = [
"vcrpy>=7.0.0,<8.0.0",
"langchain-core",
"langchain-tests",
"langchain",
]
lint = ["ruff>=0.13.1,<0.14.0"]
dev = ["langchain-core"]
@@ -55,6 +56,7 @@ typing = [
[tool.uv.sources]
langchain-core = { path = "../../core", editable = true }
langchain-tests = { path = "../../standard-tests", editable = true }
langchain = { path = "../../langchain_v1", editable = true }
[tool.mypy]
disallow_untyped_defs = "True"

View File

@@ -0,0 +1 @@
"""Tests for Anthropic middleware."""

View File

@@ -0,0 +1,122 @@
"""Tests for Anthropic prompt caching middleware."""
import warnings
from typing import Any, cast
import pytest
from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langgraph.runtime import Runtime
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
class FakeToolCallingModel(BaseChatModel):
"""Fake model for testing middleware."""
def _generate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
"""Top Level call"""
messages_string = "-".join([str(m.content) for m in messages])
message = AIMessage(content=messages_string, id="0")
return ChatResult(generations=[ChatGeneration(message=message)])
@property
def _llm_type(self) -> str:
return "fake-tool-call-model"
def test_anthropic_prompt_caching_middleware_initialization() -> None:
"""Test AnthropicPromptCachingMiddleware initialization."""
# Test with custom values
middleware = AnthropicPromptCachingMiddleware(
type="ephemeral", ttl="1h", min_messages_to_cache=5
)
assert middleware.type == "ephemeral"
assert middleware.ttl == "1h"
assert middleware.min_messages_to_cache == 5
# Test with default values
middleware = AnthropicPromptCachingMiddleware()
assert middleware.type == "ephemeral"
assert middleware.ttl == "5m"
assert middleware.min_messages_to_cache == 0
fake_request = ModelRequest(
model=FakeToolCallingModel(),
messages=[HumanMessage("Hello")],
system_prompt=None,
tool_choice=None,
tools=[],
response_format=None,
state={"messages": [HumanMessage("Hello")]},
runtime=cast(Runtime, object()),
model_settings={},
)
def mock_handler(req: ModelRequest) -> ModelResponse:
return ModelResponse(
result=[AIMessage(content="mock response")]
)
middleware.wrap_model_call(fake_request, mock_handler)
# Check that model_settings were passed through via the request
assert fake_request.model_settings == {
"cache_control": {"type": "ephemeral", "ttl": "5m"}
}
def test_anthropic_prompt_caching_middleware_unsupported_model() -> None:
"""Test AnthropicPromptCachingMiddleware with unsupported model."""
fake_request = ModelRequest(
model=FakeToolCallingModel(),
messages=[HumanMessage("Hello")],
system_prompt=None,
tool_choice=None,
tools=[],
response_format=None,
state={"messages": [HumanMessage("Hello")]},
runtime=cast(Runtime, object()),
model_settings={},
)
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="raise")
def mock_handler(req: ModelRequest) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="mock response")])
# Since we're in the langchain-anthropic package, ChatAnthropic is always
# available. Test that it raises an error for unsupported model instances
with pytest.raises(
ValueError,
match=(
"AnthropicPromptCachingMiddleware caching middleware only supports "
"Anthropic models, not instances of"
),
):
middleware.wrap_model_call(fake_request, mock_handler)
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="warn")
# Test warn behavior for unsupported model instances
with warnings.catch_warnings(record=True) as w:
result = middleware.wrap_model_call(fake_request, mock_handler)
assert isinstance(result, ModelResponse)
assert len(w) == 1
assert (
"AnthropicPromptCachingMiddleware caching middleware only supports "
"Anthropic models, not instances of"
) in str(w[-1].message)
# Test ignore behavior
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore")
result = middleware.wrap_model_call(fake_request, mock_handler)
assert isinstance(result, ModelResponse)

File diff suppressed because it is too large.