Mirror of https://github.com/hwchase17/langchain.git
Synced 2026-01-21 21:56:38 +00:00
Compare commits
6 Commits
mdrxy/fix-...replace_ap
| Author | SHA1 | Date |
|---|---|---|
|  | c04676f93e |  |
|  | 9358d12917 |  |
|  | 23f5b0cedf |  |
|  | 89e3a10cbd |  |
|  | fad0e2dbd8 |  |
|  | 6c2f0eb67a |  |
@@ -12,7 +12,6 @@ from .model_call_limit import ModelCallLimitMiddleware
 from .model_fallback import ModelFallbackMiddleware
 from .pii import PIIDetectionError, PIIMiddleware
 from .planning import PlanningMiddleware
-from .prompt_caching import AnthropicPromptCachingMiddleware
 from .summarization import SummarizationMiddleware
 from .tool_call_limit import ToolCallLimitMiddleware
 from .tool_emulator import LLMToolEmulator
@@ -33,8 +32,6 @@ from .types import (
 __all__ = [
     "AgentMiddleware",
     "AgentState",
-    # should move to langchain-anthropic if we decide to keep it
-    "AnthropicPromptCachingMiddleware",
    "ClearToolUsesEdit",
    "ContextEditingMiddleware",
    "HumanInTheLoopMiddleware",
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 from collections.abc import Awaitable, Callable
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, replace
 from inspect import iscoroutinefunction
 from typing import (
     TYPE_CHECKING,
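For context: the newly imported `dataclasses.replace` is the stdlib helper that copy-constructs a dataclass with selected fields swapped, and it is what the `override()` methods added below delegate to. A minimal standalone sketch (illustrative names only):

```python
from dataclasses import dataclass, replace


@dataclass
class Point:
    x: int
    y: int


p1 = Point(x=1, y=2)
p2 = replace(p1, y=5)  # new instance Point(x=1, y=5); p1 is left unchanged
```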
@@ -30,7 +30,7 @@ from langgraph.channels.untracked_value import UntrackedValue
 from langgraph.graph.message import add_messages
 from langgraph.types import Command  # noqa: TC002
 from langgraph.typing import ContextT
-from typing_extensions import NotRequired, Required, TypedDict, TypeVar
+from typing_extensions import NotRequired, Required, TypedDict, TypeVar, Unpack
 
 if TYPE_CHECKING:
     from langchain_core.language_models.chat_models import BaseChatModel
@@ -62,6 +62,18 @@ JumpTo = Literal["tools", "model", "end"]
 ResponseT = TypeVar("ResponseT")
 
 
+class _ModelRequestOverrides(TypedDict, total=False):
+    """Possible overrides for ModelRequest.override() method."""
+
+    model: BaseChatModel
+    system_prompt: str | None
+    messages: list[AnyMessage]
+    tool_choice: Any | None
+    tools: list[BaseTool | dict]
+    response_format: ResponseFormat | None
+    model_settings: dict[str, Any]
+
+
 @dataclass
 class ModelRequest:
     """Model request information for the agent."""
@@ -76,6 +88,36 @@ class ModelRequest:
     runtime: Runtime[ContextT]  # type: ignore[valid-type]
     model_settings: dict[str, Any] = field(default_factory=dict)
 
+    def override(self, **overrides: Unpack[_ModelRequestOverrides]) -> ModelRequest:
+        """Replace the request with a new request with the given overrides.
+
+        Returns a new `ModelRequest` instance with the specified attributes replaced.
+        This follows an immutable pattern, leaving the original request unchanged.
+
+        Args:
+            **overrides: Keyword arguments for attributes to override. Supported keys:
+                - model: BaseChatModel instance
+                - system_prompt: Optional system prompt string
+                - messages: List of messages
+                - tool_choice: Tool choice configuration
+                - tools: List of available tools
+                - response_format: Response format specification
+                - model_settings: Additional model settings
+
+        Returns:
+            New ModelRequest instance with specified overrides applied.
+
+        Examples:
+            ```python
+            # Create a new request with a different model
+            new_request = request.override(model=different_model)
+
+            # Override multiple attributes
+            new_request = request.override(system_prompt="New instructions", tool_choice="auto")
+            ```
+        """
+        return replace(self, **overrides)
+
 
 @dataclass
 class ModelResponse:
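A minimal sketch of how a middleware hook might use the new `ModelRequest.override()` API. The middleware class and prompt are hypothetical; the request/handler calling convention follows the `wrap_model_call` usage visible elsewhere in this diff:

```python
from langchain.agents.middleware.types import AgentMiddleware, ModelRequest


class StaticPromptMiddleware(AgentMiddleware):
    """Hypothetical middleware that swaps in a fixed system prompt."""

    def wrap_model_call(self, request: ModelRequest, handler):
        # override() returns a fresh request; the original stays intact,
        # so middleware earlier in the stack still sees unmodified state.
        new_request = request.override(system_prompt="You are a concise assistant.")
        return handler(new_request)
```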
@@ -81,6 +81,7 @@ from langgraph.graph.message import REMOVE_ALL_MESSAGES
 from langgraph.runtime import get_runtime
 from langgraph.types import Command, Send
 from pydantic import BaseModel, ValidationError
+from typing_extensions import Unpack
 
 if TYPE_CHECKING:
     from collections.abc import Sequence
@@ -104,6 +105,12 @@ TOOL_INVOCATION_ERROR_TEMPLATE = (
 )
 
 
+class _ToolCallRequestOverrides(TypedDict, total=False):
+    """Possible overrides for ToolCallRequest.override() method."""
+
+    tool_call: ToolCall
+
+
 @dataclass()
 class ToolCallRequest:
     """Tool execution request passed to tool call interceptors.
@@ -120,6 +127,31 @@ class ToolCallRequest:
     state: Any
     runtime: Any
 
+    def override(self, **overrides: Unpack[_ToolCallRequestOverrides]) -> ToolCallRequest:
+        """Replace the request with a new request with the given overrides.
+
+        Returns a new `ToolCallRequest` instance with the specified attributes replaced.
+        This follows an immutable pattern, leaving the original request unchanged.
+
+        Args:
+            **overrides: Keyword arguments for attributes to override. Supported keys:
+                - tool_call: Tool call dict with name, args, and id
+
+        Returns:
+            New ToolCallRequest instance with specified overrides applied.
+
+        Examples:
+            ```python
+            # Modify tool call arguments without mutating the original
+            modified_call = {**request.tool_call, "args": {"value": 10}}
+            new_request = request.override(tool_call=modified_call)
+
+            # Override multiple attributes
+            new_request = request.override(tool_call=modified_call, state=new_state)
+            ```
+        """
+        return replace(self, **overrides)
+
 
 ToolCallWrapper = Callable[
     [ToolCallRequest, Callable[[ToolCallRequest], ToolMessage | Command]],
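Similarly, a tool-call interceptor matching the `ToolCallWrapper` shape above can rewrite arguments without mutating the incoming request. A hypothetical wrapper (names are illustrative):

```python
from langchain.tools.tool_node import ToolCallRequest


def clamp_value(request: ToolCallRequest, execute):
    """Hypothetical interceptor that caps an integer 'value' argument at 100."""
    args = request.tool_call["args"]
    if args.get("value", 0) > 100:
        # Copy the tool_call dict, then build a new request via override().
        modified_call = {**request.tool_call, "args": {**args, "value": 100}}
        request = request.override(tool_call=modified_call)
    return execute(request)
```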
@@ -0,0 +1,381 @@
+"""Unit tests for override() methods on ModelRequest and ToolCallRequest."""
+
+import pytest
+from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
+from langchain_core.messages import AIMessage, HumanMessage
+from langchain_core.tools import BaseTool
+
+from langchain.agents.middleware.types import ModelRequest
+from langchain.tools.tool_node import ToolCallRequest
+
+
+class TestModelRequestOverride:
+    """Test the ModelRequest.override() method."""
+
+    def test_override_single_attribute(self) -> None:
+        """Test overriding a single attribute."""
+        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
+        original_request = ModelRequest(
+            model=model,
+            system_prompt="Original prompt",
+            messages=[HumanMessage("Hi")],
+            tool_choice=None,
+            tools=[],
+            response_format=None,
+            state={},
+            runtime=None,
+        )
+
+        new_request = original_request.override(system_prompt="New prompt")
+
+        # New request should have the overridden value
+        assert new_request.system_prompt == "New prompt"
+        # Original request should be unchanged (immutability)
+        assert original_request.system_prompt == "Original prompt"
+        # Other attributes should be the same
+        assert new_request.model == original_request.model
+        assert new_request.messages == original_request.messages
+
+    def test_override_multiple_attributes(self) -> None:
+        """Test overriding multiple attributes at once."""
+        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
+        original_request = ModelRequest(
+            model=model,
+            system_prompt="Original prompt",
+            messages=[HumanMessage("Hi")],
+            tool_choice=None,
+            tools=[],
+            response_format=None,
+            state={"count": 1},
+            runtime=None,
+        )
+
+        new_request = original_request.override(
+            system_prompt="New prompt",
+            tool_choice="auto",
+            state={"count": 2},
+        )
+
+        # Overridden values should be changed
+        assert new_request.system_prompt == "New prompt"
+        assert new_request.tool_choice == "auto"
+        assert new_request.state == {"count": 2}
+        # Original should be unchanged
+        assert original_request.system_prompt == "Original prompt"
+        assert original_request.tool_choice is None
+        assert original_request.state == {"count": 1}
+
+    def test_override_messages(self) -> None:
+        """Test overriding messages list."""
+        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
+        original_messages = [HumanMessage("Hi")]
+        new_messages = [HumanMessage("Hello"), AIMessage("Hi there")]
+
+        original_request = ModelRequest(
+            model=model,
+            system_prompt=None,
+            messages=original_messages,
+            tool_choice=None,
+            tools=[],
+            response_format=None,
+            state={},
+            runtime=None,
+        )
+
+        new_request = original_request.override(messages=new_messages)
+
+        assert new_request.messages == new_messages
+        assert original_request.messages == original_messages
+        assert len(new_request.messages) == 2
+        assert len(original_request.messages) == 1
+
+    def test_override_model_settings(self) -> None:
+        """Test overriding model_settings dict."""
+        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
+        original_request = ModelRequest(
+            model=model,
+            system_prompt=None,
+            messages=[HumanMessage("Hi")],
+            tool_choice=None,
+            tools=[],
+            response_format=None,
+            state={},
+            runtime=None,
+            model_settings={"temperature": 0.5},
+        )
+
+        new_request = original_request.override(
+            model_settings={"temperature": 0.9, "max_tokens": 100}
+        )
+
+        assert new_request.model_settings == {"temperature": 0.9, "max_tokens": 100}
+        assert original_request.model_settings == {"temperature": 0.5}
+
+    def test_override_with_none_value(self) -> None:
+        """Test overriding with None value."""
+        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
+        original_request = ModelRequest(
+            model=model,
+            system_prompt="Original prompt",
+            messages=[HumanMessage("Hi")],
+            tool_choice="auto",
+            tools=[],
+            response_format=None,
+            state={},
+            runtime=None,
+        )
+
+        new_request = original_request.override(
+            system_prompt=None,
+            tool_choice=None,
+        )
+
+        assert new_request.system_prompt is None
+        assert new_request.tool_choice is None
+        assert original_request.system_prompt == "Original prompt"
+        assert original_request.tool_choice == "auto"
+
+    def test_override_preserves_identity_of_unchanged_objects(self) -> None:
+        """Test that unchanged attributes maintain object identity."""
+        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
+        messages = [HumanMessage("Hi")]
+        state = {"key": "value"}
+
+        original_request = ModelRequest(
+            model=model,
+            system_prompt="Original prompt",
+            messages=messages,
+            tool_choice=None,
+            tools=[],
+            response_format=None,
+            state=state,
+            runtime=None,
+        )
+
+        new_request = original_request.override(system_prompt="New prompt")
+
+        # Unchanged objects should be the same instance
+        assert new_request.messages is messages
+        assert new_request.state is state
+        assert new_request.model is model
+
+    def test_override_chaining(self) -> None:
+        """Test chaining multiple override calls."""
+        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
+        original_request = ModelRequest(
+            model=model,
+            system_prompt="Prompt 1",
+            messages=[HumanMessage("Hi")],
+            tool_choice=None,
+            tools=[],
+            response_format=None,
+            state={"count": 1},
+            runtime=None,
+        )
+
+        final_request = (
+            original_request.override(system_prompt="Prompt 2")
+            .override(state={"count": 2})
+            .override(tool_choice="auto")
+        )
+
+        assert final_request.system_prompt == "Prompt 2"
+        assert final_request.state == {"count": 2}
+        assert final_request.tool_choice == "auto"
+        # Original should be unchanged
+        assert original_request.system_prompt == "Prompt 1"
+        assert original_request.state == {"count": 1}
+        assert original_request.tool_choice is None
+
+
+class TestToolCallRequestOverride:
+    """Test the ToolCallRequest.override() method."""
+
+    def test_override_tool_call(self) -> None:
+        """Test overriding tool_call dict."""
+        from langchain_core.tools import tool
+
+        @tool
+        def test_tool(x: int) -> str:
+            """A test tool."""
+            return f"Result: {x}"
+
+        original_call = {"name": "test_tool", "args": {"x": 5}, "id": "1", "type": "tool_call"}
+        modified_call = {"name": "test_tool", "args": {"x": 10}, "id": "1", "type": "tool_call"}
+
+        original_request = ToolCallRequest(
+            tool_call=original_call,
+            tool=test_tool,
+            state={"messages": []},
+            runtime=None,
+        )
+
+        new_request = original_request.override(tool_call=modified_call)
+
+        # New request should have modified tool_call
+        assert new_request.tool_call["args"]["x"] == 10
+        # Original should be unchanged
+        assert original_request.tool_call["args"]["x"] == 5
+        # Other attributes should be the same
+        assert new_request.tool is original_request.tool
+        assert new_request.state is original_request.state
+
+    def test_override_state(self) -> None:
+        """Test overriding state."""
+        from langchain_core.tools import tool
+
+        @tool
+        def test_tool(x: int) -> str:
+            """A test tool."""
+            return f"Result: {x}"
+
+        tool_call = {"name": "test_tool", "args": {"x": 5}, "id": "1", "type": "tool_call"}
+        original_state = {"messages": [HumanMessage("Hi")]}
+        new_state = {"messages": [HumanMessage("Hi"), AIMessage("Hello")]}
+
+        original_request = ToolCallRequest(
+            tool_call=tool_call,
+            tool=test_tool,
+            state=original_state,
+            runtime=None,
+        )
+
+        new_request = original_request.override(state=new_state)
+
+        assert len(new_request.state["messages"]) == 2
+        assert len(original_request.state["messages"]) == 1
+
+    def test_override_multiple_attributes(self) -> None:
+        """Test overriding multiple attributes at once."""
+        from langchain_core.tools import tool
+
+        @tool
+        def test_tool(x: int) -> str:
+            """A test tool."""
+            return f"Result: {x}"
+
+        @tool
+        def another_tool(y: str) -> str:
+            """Another test tool."""
+            return f"Output: {y}"
+
+        original_call = {"name": "test_tool", "args": {"x": 5}, "id": "1", "type": "tool_call"}
+        modified_call = {
+            "name": "another_tool",
+            "args": {"y": "hello"},
+            "id": "2",
+            "type": "tool_call",
+        }
+
+        original_request = ToolCallRequest(
+            tool_call=original_call,
+            tool=test_tool,
+            state={"count": 1},
+            runtime=None,
+        )
+
+        new_request = original_request.override(
+            tool_call=modified_call,
+            tool=another_tool,
+            state={"count": 2},
+        )
+
+        assert new_request.tool_call["name"] == "another_tool"
+        assert new_request.tool.name == "another_tool"
+        assert new_request.state == {"count": 2}
+        # Original unchanged
+        assert original_request.tool_call["name"] == "test_tool"
+        assert original_request.tool.name == "test_tool"
+        assert original_request.state == {"count": 1}
+
+    def test_override_with_copy_pattern(self) -> None:
+        """Test common pattern of copying and modifying tool_call."""
+        from langchain_core.tools import tool
+
+        @tool
+        def test_tool(value: int) -> str:
+            """A test tool."""
+            return f"Result: {value}"
+
+        original_call = {
+            "name": "test_tool",
+            "args": {"value": 5},
+            "id": "call_123",
+            "type": "tool_call",
+        }
+
+        original_request = ToolCallRequest(
+            tool_call=original_call,
+            tool=test_tool,
+            state={},
+            runtime=None,
+        )
+
+        # Common pattern: copy tool_call and modify args
+        modified_call = {**original_request.tool_call, "args": {"value": 10}}
+        new_request = original_request.override(tool_call=modified_call)
+
+        assert new_request.tool_call["args"]["value"] == 10
+        assert new_request.tool_call["id"] == "call_123"
+        assert new_request.tool_call["name"] == "test_tool"
+        # Original unchanged
+        assert original_request.tool_call["args"]["value"] == 5
+
+    def test_override_preserves_identity(self) -> None:
+        """Test that unchanged attributes maintain object identity."""
+        from langchain_core.tools import tool
+
+        @tool
+        def test_tool(x: int) -> str:
+            """A test tool."""
+            return f"Result: {x}"
+
+        tool_call = {"name": "test_tool", "args": {"x": 5}, "id": "1", "type": "tool_call"}
+        state = {"messages": []}
+
+        original_request = ToolCallRequest(
+            tool_call=tool_call,
+            tool=test_tool,
+            state=state,
+            runtime=None,
+        )
+
+        new_call = {"name": "test_tool", "args": {"x": 10}, "id": "1", "type": "tool_call"}
+        new_request = original_request.override(tool_call=new_call)
+
+        # Unchanged objects should be the same instance
+        assert new_request.tool is test_tool
+        assert new_request.state is state
+
+    def test_override_chaining(self) -> None:
+        """Test chaining multiple override calls."""
+        from langchain_core.tools import tool
+
+        @tool
+        def test_tool(x: int) -> str:
+            """A test tool."""
+            return f"Result: {x}"
+
+        tool_call = {"name": "test_tool", "args": {"x": 5}, "id": "1", "type": "tool_call"}
+
+        original_request = ToolCallRequest(
+            tool_call=tool_call,
+            tool=test_tool,
+            state={"count": 1},
+            runtime=None,
+        )
+
+        call_2 = {"name": "test_tool", "args": {"x": 10}, "id": "1", "type": "tool_call"}
+        call_3 = {"name": "test_tool", "args": {"x": 15}, "id": "1", "type": "tool_call"}
+
+        final_request = (
+            original_request.override(tool_call=call_2)
+            .override(state={"count": 2})
+            .override(tool_call=call_3)
+        )
+
+        assert final_request.tool_call["args"]["x"] == 15
+        assert final_request.state == {"count": 2}
+        # Original unchanged
+        assert original_request.tool_call["args"]["x"] == 5
+        assert original_request.state == {"count": 1}
@@ -44,7 +44,6 @@ from langchain.agents.middleware.model_call_limit import (
     ModelCallLimitExceededError,
 )
 from langchain.agents.middleware.model_fallback import ModelFallbackMiddleware
-from langchain.agents.middleware.prompt_caching import AnthropicPromptCachingMiddleware
 from langchain.agents.middleware.summarization import SummarizationMiddleware
 from langchain.agents.middleware.types import (
     AgentMiddleware,
@@ -1024,115 +1023,6 @@ def test_human_in_the_loop_middleware_description_as_callable() -> None:
     assert captured_request["action_requests"][1]["description"] == "Static description"
 
 
-# Tests for AnthropicPromptCachingMiddleware
-def test_anthropic_prompt_caching_middleware_initialization() -> None:
-    """Test AnthropicPromptCachingMiddleware initialization."""
-    # Test with custom values
-    middleware = AnthropicPromptCachingMiddleware(
-        type="ephemeral", ttl="1h", min_messages_to_cache=5
-    )
-    assert middleware.type == "ephemeral"
-    assert middleware.ttl == "1h"
-    assert middleware.min_messages_to_cache == 5
-
-    # Test with default values
-    middleware = AnthropicPromptCachingMiddleware()
-    assert middleware.type == "ephemeral"
-    assert middleware.ttl == "5m"
-    assert middleware.min_messages_to_cache == 0
-
-    fake_request = ModelRequest(
-        model=FakeToolCallingModel(),
-        messages=[HumanMessage("Hello")],
-        system_prompt=None,
-        tool_choice=None,
-        tools=[],
-        response_format=None,
-        state={"messages": [HumanMessage("Hello")]},
-        runtime=cast(Runtime, object()),
-        model_settings={},
-    )
-
-    def mock_handler(req: ModelRequest) -> AIMessage:
-        return AIMessage(content="mock response", **req.model_settings)
-
-    result = middleware.wrap_model_call(fake_request, mock_handler)
-    # Check that model_settings were passed through via the request
-    assert fake_request.model_settings == {"cache_control": {"type": "ephemeral", "ttl": "5m"}}
-
-
-def test_anthropic_prompt_caching_middleware_unsupported_model() -> None:
-    """Test AnthropicPromptCachingMiddleware with unsupported model."""
-    from typing import cast
-
-    fake_request = ModelRequest(
-        model=FakeToolCallingModel(),
-        messages=[HumanMessage("Hello")],
-        system_prompt=None,
-        tool_choice=None,
-        tools=[],
-        response_format=None,
-        state={"messages": [HumanMessage("Hello")]},
-        runtime=cast(Runtime, object()),
-        model_settings={},
-    )
-
-    middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="raise")
-
-    def mock_handler(req: ModelRequest) -> AIMessage:
-        return AIMessage(content="mock response")
-
-    with pytest.raises(
-        ValueError,
-        match="AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models. Please install langchain-anthropic.",
-    ):
-        middleware.wrap_model_call(fake_request, mock_handler)
-
-    langchain_anthropic = ModuleType("langchain_anthropic")
-
-    class MockChatAnthropic:
-        pass
-
-    langchain_anthropic.ChatAnthropic = MockChatAnthropic
-
-    with patch.dict("sys.modules", {"langchain_anthropic": langchain_anthropic}):
-        with pytest.raises(
-            ValueError,
-            match="AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models, not instances of",
-        ):
-            middleware.wrap_model_call(fake_request, mock_handler)
-
-    middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="warn")
-
-    with warnings.catch_warnings(record=True) as w:
-        result = middleware.wrap_model_call(fake_request, mock_handler)
-        assert len(w) == 1
-        assert (
-            "AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models. Please install langchain-anthropic."
-            in str(w[-1].message)
-        )
-        assert isinstance(result, AIMessage)
-
-    with warnings.catch_warnings(record=True) as w:
-        with patch.dict("sys.modules", {"langchain_anthropic": langchain_anthropic}):
-            result = middleware.wrap_model_call(fake_request, mock_handler)
-            assert isinstance(result, AIMessage)
-            assert len(w) == 1
-            assert (
-                "AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models, not instances of"
-                in str(w[-1].message)
-            )
-
-    middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore")
-
-    result = middleware.wrap_model_call(fake_request, mock_handler)
-    assert isinstance(result, AIMessage)
-
-    with patch.dict("sys.modules", {"langchain_anthropic": {"ChatAnthropic": object()}}):
-        result = middleware.wrap_model_call(fake_request, mock_handler)
-        assert isinstance(result, AIMessage)
-
-
 # Tests for SummarizationMiddleware
 def test_summarization_middleware_initialization() -> None:
     """Test SummarizationMiddleware initialization."""
@@ -0,0 +1,9 @@
+"""Middleware for Anthropic models."""
+
+from langchain_anthropic.middleware.prompt_caching import (
+    AnthropicPromptCachingMiddleware,
+)
+
+__all__ = [
+    "AnthropicPromptCachingMiddleware",
+]
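With the middleware relocated, downstream code imports it from `langchain_anthropic.middleware` instead of `langchain.agents.middleware`. A sketch of the new import path (constructor arguments are illustrative but match the signature shown in this diff):

```python
# Old location, removed by this change:
# from langchain.agents.middleware import AnthropicPromptCachingMiddleware

# New location:
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware

middleware = AnthropicPromptCachingMiddleware(ttl="1h", min_messages_to_cache=3)
```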
@@ -1,15 +1,28 @@
-"""Anthropic prompt caching middleware."""
+"""Anthropic prompt caching middleware.
+
+Requires:
+    - langchain: For agent middleware framework
+    - langchain-anthropic: For ChatAnthropic model (already a dependency)
+"""
 
 from collections.abc import Callable
 from typing import Literal
 from warnings import warn
 
-from langchain.agents.middleware.types import (
-    AgentMiddleware,
-    ModelCallResult,
-    ModelRequest,
-    ModelResponse,
-)
+try:
+    from langchain.agents.middleware.types import (
+        AgentMiddleware,
+        ModelCallResult,
+        ModelRequest,
+        ModelResponse,
+    )
+except ImportError as e:
+    msg = (
+        "AnthropicPromptCachingMiddleware requires 'langchain' to be installed. "
+        "This middleware is designed for use with LangChain agents. "
+        "Install it with: pip install langchain"
+    )
+    raise ImportError(msg) from e
 
 
 class AnthropicPromptCachingMiddleware(AgentMiddleware):
@@ -17,13 +30,15 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
 
     Optimizes API usage by caching conversation prefixes for Anthropic models.
 
+    Requires both 'langchain' and 'langchain-anthropic' packages to be installed.
+
     Learn more about Anthropic prompt caching
     [here](https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching).
     """
 
     def __init__(
         self,
-        type: Literal["ephemeral"] = "ephemeral",
+        type: Literal["ephemeral"] = "ephemeral",  # noqa: A002
         ttl: Literal["5m", "1h"] = "5m",
         min_messages_to_cache: int = 0,
         unsupported_model_behavior: Literal["ignore", "warn", "raise"] = "warn",
@@ -32,13 +47,15 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
 
         Args:
             type: The type of cache to use, only "ephemeral" is supported.
-            ttl: The time to live for the cache, only "5m" and "1h" are supported.
-            min_messages_to_cache: The minimum number of messages until the cache is used,
-                default is 0.
-            unsupported_model_behavior: The behavior to take when an unsupported model is used.
-                "ignore" will ignore the unsupported model and continue without caching.
-                "warn" will warn the user and continue without caching.
-                "raise" will raise an error and stop the agent.
+            ttl: The time to live for the cache, only "5m" and "1h" are
+                supported.
+            min_messages_to_cache: The minimum number of messages until the
+                cache is used, default is 0.
+            unsupported_model_behavior: The behavior to take when an
+                unsupported model is used. "ignore" will ignore the unsupported
+                model and continue without caching. "warn" will warn the user
+                and continue without caching. "raise" will raise an error and
+                stop the agent.
         """
         self.type = type
         self.ttl = ttl
@@ -53,18 +70,20 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
         """Modify the model request to add cache control blocks."""
         try:
             from langchain_anthropic import ChatAnthropic
+
+            chat_anthropic_cls: type | None = ChatAnthropic
         except ImportError:
-            ChatAnthropic = None  # noqa: N806
+            chat_anthropic_cls = None
 
         msg: str | None = None
 
-        if ChatAnthropic is None:
+        if chat_anthropic_cls is None:
             msg = (
                 "AnthropicPromptCachingMiddleware caching middleware only supports "
                 "Anthropic models. "
                 "Please install langchain-anthropic."
             )
-        elif not isinstance(request.model, ChatAnthropic):
+        elif not isinstance(request.model, chat_anthropic_cls):
             msg = (
                 "AnthropicPromptCachingMiddleware caching middleware only supports "
                 f"Anthropic models, not instances of {type(request.model)}"
@@ -79,7 +98,9 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
             return handler(request)
 
         messages_count = (
-            len(request.messages) + 1 if request.system_prompt else len(request.messages)
+            len(request.messages) + 1
+            if request.system_prompt
+            else len(request.messages)
         )
         if messages_count < self.min_messages_to_cache:
             return handler(request)
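Worked example of the (unchanged) threshold logic: with `min_messages_to_cache=5`, a request carrying a system prompt and three messages counts as 3 + 1 = 4, so the middleware returns `handler(request)` without attaching cache-control settings.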
@@ -41,6 +41,7 @@ test = [
     "vcrpy>=7.0.0,<8.0.0",
     "langchain-core",
     "langchain-tests",
+    "langchain",
 ]
 lint = ["ruff>=0.13.1,<0.14.0"]
 dev = ["langchain-core"]
@@ -55,6 +56,7 @@ typing = [
 [tool.uv.sources]
 langchain-core = { path = "../../core", editable = true }
 langchain-tests = { path = "../../standard-tests", editable = true }
+langchain = { path = "../../langchain_v1", editable = true }
 
 [tool.mypy]
 disallow_untyped_defs = "True"
@@ -0,0 +1 @@
+"""Tests for Anthropic middleware."""
@@ -0,0 +1,122 @@
+"""Tests for Anthropic prompt caching middleware."""
+
+import warnings
+from typing import Any, cast
+
+import pytest
+from langchain.agents.middleware.types import ModelRequest, ModelResponse
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
+from langchain_core.outputs import ChatGeneration, ChatResult
+from langgraph.runtime import Runtime
+
+from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
+
+
+class FakeToolCallingModel(BaseChatModel):
+    """Fake model for testing middleware."""
+
+    def _generate(
+        self,
+        messages: list[BaseMessage],
+        stop: list[str] | None = None,
+        run_manager: CallbackManagerForLLMRun | None = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        """Top level call."""
+        messages_string = "-".join([str(m.content) for m in messages])
+        message = AIMessage(content=messages_string, id="0")
+        return ChatResult(generations=[ChatGeneration(message=message)])
+
+    @property
+    def _llm_type(self) -> str:
+        return "fake-tool-call-model"
+
+
+def test_anthropic_prompt_caching_middleware_initialization() -> None:
+    """Test AnthropicPromptCachingMiddleware initialization."""
+    # Test with custom values
+    middleware = AnthropicPromptCachingMiddleware(
+        type="ephemeral", ttl="1h", min_messages_to_cache=5
+    )
+    assert middleware.type == "ephemeral"
+    assert middleware.ttl == "1h"
+    assert middleware.min_messages_to_cache == 5
+
+    # Test with default values
+    middleware = AnthropicPromptCachingMiddleware()
+    assert middleware.type == "ephemeral"
+    assert middleware.ttl == "5m"
+    assert middleware.min_messages_to_cache == 0
+
+    fake_request = ModelRequest(
+        model=FakeToolCallingModel(),
+        messages=[HumanMessage("Hello")],
+        system_prompt=None,
+        tool_choice=None,
+        tools=[],
+        response_format=None,
+        state={"messages": [HumanMessage("Hello")]},
+        runtime=cast(Runtime, object()),
+        model_settings={},
+    )
+
+    def mock_handler(req: ModelRequest) -> ModelResponse:
+        return ModelResponse(
+            result=[AIMessage(content="mock response")]
+        )
+
+    middleware.wrap_model_call(fake_request, mock_handler)
+    # Check that model_settings were passed through via the request
+    assert fake_request.model_settings == {
+        "cache_control": {"type": "ephemeral", "ttl": "5m"}
+    }
+
+
+def test_anthropic_prompt_caching_middleware_unsupported_model() -> None:
+    """Test AnthropicPromptCachingMiddleware with unsupported model."""
+    fake_request = ModelRequest(
+        model=FakeToolCallingModel(),
+        messages=[HumanMessage("Hello")],
+        system_prompt=None,
+        tool_choice=None,
+        tools=[],
+        response_format=None,
+        state={"messages": [HumanMessage("Hello")]},
+        runtime=cast(Runtime, object()),
+        model_settings={},
+    )
+
+    middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="raise")
+
+    def mock_handler(req: ModelRequest) -> ModelResponse:
+        return ModelResponse(result=[AIMessage(content="mock response")])
+
+    # Since we're in the langchain-anthropic package, ChatAnthropic is always
+    # available. Test that it raises an error for unsupported model instances.
+    with pytest.raises(
+        ValueError,
+        match=(
+            "AnthropicPromptCachingMiddleware caching middleware only supports "
+            "Anthropic models, not instances of"
+        ),
+    ):
+        middleware.wrap_model_call(fake_request, mock_handler)
+
+    middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="warn")
+
+    # Test warn behavior for unsupported model instances
+    with warnings.catch_warnings(record=True) as w:
+        result = middleware.wrap_model_call(fake_request, mock_handler)
+        assert isinstance(result, ModelResponse)
+        assert len(w) == 1
+        assert (
+            "AnthropicPromptCachingMiddleware caching middleware only supports "
+            "Anthropic models, not instances of"
+        ) in str(w[-1].message)
+
+    # Test ignore behavior
+    middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore")
+    result = middleware.wrap_model_call(fake_request, mock_handler)
+    assert isinstance(result, ModelResponse)
libs/partners/anthropic/uv.lock (generated, 1382 lines): file diff suppressed because it is too large.