Mirror of https://github.com/hwchase17/langchain.git, synced 2026-02-21 14:43:07 +00:00.

Compare commits: 8 commits, comparing langchain-… with async_impl.
Commits in this compare (SHA1):

- fed37ddf96
- 9f6b660f00
- 00c2380019
- c53ffe4b82
- 23f5b0cedf
- 89e3a10cbd
- fad0e2dbd8
- 6c2f0eb67a
@@ -12,7 +12,6 @@ from .model_call_limit import ModelCallLimitMiddleware
 from .model_fallback import ModelFallbackMiddleware
 from .pii import PIIDetectionError, PIIMiddleware
 from .planning import PlanningMiddleware
-from .prompt_caching import AnthropicPromptCachingMiddleware
 from .summarization import SummarizationMiddleware
 from .tool_call_limit import ToolCallLimitMiddleware
 from .tool_emulator import LLMToolEmulator
@@ -33,8 +32,6 @@ from .types import (
 __all__ = [
     "AgentMiddleware",
     "AgentState",
-    # should move to langchain-anthropic if we decide to keep it
-    "AnthropicPromptCachingMiddleware",
     "ClearToolUsesEdit",
     "ContextEditingMiddleware",
     "HumanInTheLoopMiddleware",
@@ -8,7 +8,7 @@ with any LangChain chat model.
 
 from __future__ import annotations
 
-from collections.abc import Callable, Iterable, Sequence
+from collections.abc import Awaitable, Callable, Iterable, Sequence
 from dataclasses import dataclass
 from typing import Literal
 
@@ -239,6 +239,34 @@ class ContextEditingMiddleware(AgentMiddleware):
 
         return handler(request)
 
+    async def awrap_model_call(
+        self,
+        request: ModelRequest,
+        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+    ) -> ModelCallResult:
+        """Apply context edits before invoking the model via handler (async version)."""
+        if not request.messages:
+            return await handler(request)
+
+        if self.token_count_method == "approximate":  # noqa: S105
+
+            def count_tokens(messages: Sequence[BaseMessage]) -> int:
+                return count_tokens_approximately(messages)
+        else:
+            system_msg = (
+                [SystemMessage(content=request.system_prompt)] if request.system_prompt else []
+            )
+
+            def count_tokens(messages: Sequence[BaseMessage]) -> int:
+                return request.model.get_num_tokens_from_messages(
+                    system_msg + list(messages), request.tools
+                )
+
+        for edit in self.edits:
+            edit.apply(request.messages, count_tokens=count_tokens)
+
+        return await handler(request)
+
 
 __all__ = [
     "ClearToolUsesEdit",
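The async hook above mirrors the existing sync path: it builds a `count_tokens` closure (approximate or model-based), lets each configured edit trim the message list in place, and only then awaits the handler. As a hedged illustration of that closure-selection pattern, here is a small self-contained sketch; `toy_count_approx` and `ToyEdit` are hypothetical stand-ins and not LangChain APIs.

```python
# Stand-alone sketch of the count_tokens-closure pattern used by the
# awrap_model_call hunk above. Everything here is a toy stand-in.
from collections.abc import Callable, Sequence


def toy_count_approx(messages: Sequence[str]) -> int:
    # Rough heuristic: roughly 4 characters per token.
    return sum(len(m) for m in messages) // 4


class ToyEdit:
    """Clears all but the most recent message once the count crosses a trigger."""

    def __init__(self, trigger: int) -> None:
        self.trigger = trigger

    def apply(
        self, messages: list[str], *, count_tokens: Callable[[Sequence[str]], int]
    ) -> None:
        if count_tokens(messages) > self.trigger:
            for i in range(len(messages) - 1):
                messages[i] = "[cleared]"


messages = ["x" * 400, "y" * 400, "keep me"]
ToyEdit(trigger=100).apply(messages, count_tokens=toy_count_approx)
print(messages)  # ['[cleared]', '[cleared]', 'keep me']
```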
@@ -13,7 +13,7 @@ from langchain.agents.middleware.types import (
 from langchain.chat_models import init_chat_model
 
 if TYPE_CHECKING:
-    from collections.abc import Callable
+    from collections.abc import Awaitable, Callable
 
     from langchain_core.language_models.chat_models import BaseChatModel
 
@@ -102,3 +102,38 @@ class ModelFallbackMiddleware(AgentMiddleware):
                 continue
 
         raise last_exception
+
+    async def awrap_model_call(
+        self,
+        request: ModelRequest,
+        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+    ) -> ModelCallResult:
+        """Try fallback models in sequence on errors (async version).
+
+        Args:
+            request: Initial model request.
+            handler: Async callback to execute the model.
+
+        Returns:
+            AIMessage from successful model call.
+
+        Raises:
+            Exception: If all models fail, re-raises last exception.
+        """
+        # Try primary model first
+        last_exception: Exception
+        try:
+            return await handler(request)
+        except Exception as e:  # noqa: BLE001
+            last_exception = e
+
+        # Try fallback models
+        for fallback_model in self.models:
+            request.model = fallback_model
+            try:
+                return await handler(request)
+            except Exception as e:  # noqa: BLE001
+                last_exception = e
+                continue
+
+        raise last_exception
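For intuition, here is a stand-alone sketch of the retry loop the async hunk adds: try the primary handler once, then swap in each fallback and retry, re-raising the last error if everything fails. The "models" below are plain async callables and are hypothetical stand-ins, not LangChain models or the real middleware API.

```python
# Toy version of the ModelFallbackMiddleware retry loop shown in the diff above.
import asyncio
from collections.abc import Awaitable, Callable


async def call_with_fallbacks(
    primary: Callable[[], Awaitable[str]],
    fallbacks: list[Callable[[], Awaitable[str]]],
) -> str:
    # Try the primary "model" once, then each fallback in order,
    # re-raising the last error if everything fails.
    last_exception: Exception
    try:
        return await primary()
    except Exception as e:
        last_exception = e

    for fallback in fallbacks:
        try:
            return await fallback()
        except Exception as e:
            last_exception = e
            continue

    raise last_exception


async def failing() -> str:
    raise ValueError("primary failed")


async def working() -> str:
    return "fallback response"


print(asyncio.run(call_with_fallbacks(failing, [working])))  # fallback response
```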
@@ -6,7 +6,7 @@ from __future__ import annotations
 from typing import TYPE_CHECKING, Annotated, Literal
 
 if TYPE_CHECKING:
-    from collections.abc import Callable
+    from collections.abc import Awaitable, Callable
 
     from langchain_core.messages import ToolMessage
     from langchain_core.tools import tool
@@ -204,3 +204,16 @@ class PlanningMiddleware(AgentMiddleware):
             else self.system_prompt
         )
         return handler(request)
+
+    async def awrap_model_call(
+        self,
+        request: ModelRequest,
+        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+    ) -> ModelCallResult:
+        """Update the system prompt to include the todo system prompt (async version)."""
+        request.system_prompt = (
+            request.system_prompt + "\n\n" + self.system_prompt
+            if request.system_prompt
+            else self.system_prompt
+        )
+        return await handler(request)
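To make the wrap/handler contract concrete, here is a minimal stand-alone sketch of the prompt-prepending pattern PlanningMiddleware uses in both its sync and async hooks: mutate the request, then delegate to the (awaited) handler. `ToyRequest`, `PLANNING_PROMPT`, and the handler are hypothetical stand-ins, not LangChain code.

```python
# Toy illustration of the handler-wrapping pattern added in this changeset.
import asyncio
from collections.abc import Awaitable, Callable
from dataclasses import dataclass

PLANNING_PROMPT = "Use the write_todos tool to keep a plan up to date."


@dataclass
class ToyRequest:
    system_prompt: str | None = None


async def awrap_model_call(
    request: ToyRequest, handler: Callable[[ToyRequest], Awaitable[str]]
) -> str:
    # Same mutation as the sync path, but the handler is awaited.
    request.system_prompt = (
        request.system_prompt + "\n\n" + PLANNING_PROMPT
        if request.system_prompt
        else PLANNING_PROMPT
    )
    return await handler(request)


async def handler(req: ToyRequest) -> str:
    return f"model saw: {req.system_prompt!r}"


print(asyncio.run(awrap_model_call(ToyRequest(system_prompt="Be terse."), handler)))
```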
@@ -1,89 +0,0 @@
-"""Anthropic prompt caching middleware."""
-
-from collections.abc import Callable
-from typing import Literal
-from warnings import warn
-
-from langchain.agents.middleware.types import (
-    AgentMiddleware,
-    ModelCallResult,
-    ModelRequest,
-    ModelResponse,
-)
-
-
-class AnthropicPromptCachingMiddleware(AgentMiddleware):
-    """Prompt Caching Middleware.
-
-    Optimizes API usage by caching conversation prefixes for Anthropic models.
-
-    Learn more about Anthropic prompt caching
-    [here](https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching).
-    """
-
-    def __init__(
-        self,
-        type: Literal["ephemeral"] = "ephemeral",
-        ttl: Literal["5m", "1h"] = "5m",
-        min_messages_to_cache: int = 0,
-        unsupported_model_behavior: Literal["ignore", "warn", "raise"] = "warn",
-    ) -> None:
-        """Initialize the middleware with cache control settings.
-
-        Args:
-            type: The type of cache to use, only "ephemeral" is supported.
-            ttl: The time to live for the cache, only "5m" and "1h" are supported.
-            min_messages_to_cache: The minimum number of messages until the cache is used,
-                default is 0.
-            unsupported_model_behavior: The behavior to take when an unsupported model is used.
-                "ignore" will ignore the unsupported model and continue without caching.
-                "warn" will warn the user and continue without caching.
-                "raise" will raise an error and stop the agent.
-        """
-        self.type = type
-        self.ttl = ttl
-        self.min_messages_to_cache = min_messages_to_cache
-        self.unsupported_model_behavior = unsupported_model_behavior
-
-    def wrap_model_call(
-        self,
-        request: ModelRequest,
-        handler: Callable[[ModelRequest], ModelResponse],
-    ) -> ModelCallResult:
-        """Modify the model request to add cache control blocks."""
-        try:
-            from langchain_anthropic import ChatAnthropic
-        except ImportError:
-            ChatAnthropic = None  # noqa: N806
-
-        msg: str | None = None
-
-        if ChatAnthropic is None:
-            msg = (
-                "AnthropicPromptCachingMiddleware caching middleware only supports "
-                "Anthropic models. "
-                "Please install langchain-anthropic."
-            )
-        elif not isinstance(request.model, ChatAnthropic):
-            msg = (
-                "AnthropicPromptCachingMiddleware caching middleware only supports "
-                f"Anthropic models, not instances of {type(request.model)}"
-            )
-
-        if msg is not None:
-            if self.unsupported_model_behavior == "raise":
-                raise ValueError(msg)
-            if self.unsupported_model_behavior == "warn":
-                warn(msg, stacklevel=3)
-            else:
-                return handler(request)
-
-        messages_count = (
-            len(request.messages) + 1 if request.system_prompt else len(request.messages)
-        )
-        if messages_count < self.min_messages_to_cache:
-            return handler(request)
-
-        request.model_settings["cache_control"] = {"type": self.type, "ttl": self.ttl}
-
-        return handler(request)
@@ -233,3 +233,169 @@ def test_exclude_tools_prevents_clearing() -> None:
 
 def _fake_runtime() -> Runtime:
     return cast(Runtime, object())
+
+
+async def test_no_edit_when_below_trigger_async() -> None:
+    """Test async version of context editing with no edit when below trigger."""
+    tool_call_id = "call-1"
+    ai_message = AIMessage(
+        content="",
+        tool_calls=[{"id": tool_call_id, "name": "search", "args": {}}],
+    )
+    tool_message = ToolMessage(content="12345", tool_call_id=tool_call_id)
+
+    state, request = _make_state_and_request([ai_message, tool_message])
+    middleware = ContextEditingMiddleware(
+        edits=[ClearToolUsesEdit(trigger=50)],
+    )
+
+    async def mock_handler(req: ModelRequest) -> AIMessage:
+        return AIMessage(content="mock response")
+
+    # Call awrap_model_call which modifies the request
+    await middleware.awrap_model_call(request, mock_handler)
+
+    # The request should have been modified in place
+    assert request.messages[0].content == ""
+    assert request.messages[1].content == "12345"
+    assert state["messages"] == request.messages
+
+
+async def test_clear_tool_outputs_and_inputs_async() -> None:
+    """Test async version of clearing tool outputs and inputs."""
+    tool_call_id = "call-2"
+    ai_message = AIMessage(
+        content=[
+            {"type": "tool_call", "id": tool_call_id, "name": "search", "args": {"query": "foo"}}
+        ],
+        tool_calls=[{"id": tool_call_id, "name": "search", "args": {"query": "foo"}}],
+    )
+    tool_message = ToolMessage(content="x" * 200, tool_call_id=tool_call_id)
+
+    state, request = _make_state_and_request([ai_message, tool_message])
+
+    edit = ClearToolUsesEdit(
+        trigger=50,
+        clear_at_least=10,
+        clear_tool_inputs=True,
+        keep=0,
+        placeholder="[cleared output]",
+    )
+    middleware = ContextEditingMiddleware(edits=[edit])
+
+    async def mock_handler(req: ModelRequest) -> AIMessage:
+        return AIMessage(content="mock response")
+
+    # Call awrap_model_call which modifies the request
+    await middleware.awrap_model_call(request, mock_handler)
+
+    cleared_ai = request.messages[0]
+    cleared_tool = request.messages[1]
+
+    assert isinstance(cleared_tool, ToolMessage)
+    assert cleared_tool.content == "[cleared output]"
+    assert cleared_tool.response_metadata["context_editing"]["cleared"] is True
+
+    assert isinstance(cleared_ai, AIMessage)
+    assert cleared_ai.tool_calls[0]["args"] == {}
+    context_meta = cleared_ai.response_metadata.get("context_editing")
+    assert context_meta is not None
+    assert context_meta["cleared_tool_inputs"] == [tool_call_id]
+
+    assert state["messages"] == request.messages
+
+
+async def test_respects_keep_last_tool_results_async() -> None:
+    """Test async version respects keep parameter for last tool results."""
+    conversation: list[AIMessage | ToolMessage] = []
+    edits = [
+        ("call-a", "tool-output-a" * 5),
+        ("call-b", "tool-output-b" * 5),
+        ("call-c", "tool-output-c" * 5),
+    ]
+
+    for call_id, text in edits:
+        conversation.append(
+            AIMessage(
+                content="",
+                tool_calls=[{"id": call_id, "name": "tool", "args": {"input": call_id}}],
+            )
+        )
+        conversation.append(ToolMessage(content=text, tool_call_id=call_id))
+
+    state, request = _make_state_and_request(conversation)
+
+    middleware = ContextEditingMiddleware(
+        edits=[
+            ClearToolUsesEdit(
+                trigger=50,
+                keep=1,
+                placeholder="[cleared]",
+            )
+        ],
+        token_count_method="model",
+    )
+
+    async def mock_handler(req: ModelRequest) -> AIMessage:
+        return AIMessage(content="mock response")
+
+    # Call awrap_model_call which modifies the request
+    await middleware.awrap_model_call(request, mock_handler)
+
+    cleared_messages = [
+        msg
+        for msg in request.messages
+        if isinstance(msg, ToolMessage) and msg.content == "[cleared]"
+    ]
+
+    assert len(cleared_messages) == 2
+    assert isinstance(request.messages[-1], ToolMessage)
+    assert request.messages[-1].content != "[cleared]"
+
+
+async def test_exclude_tools_prevents_clearing_async() -> None:
+    """Test async version of excluding tools from clearing."""
+    search_call = "call-search"
+    calc_call = "call-calc"
+
+    state, request = _make_state_and_request(
+        [
+            AIMessage(
+                content="",
+                tool_calls=[{"id": search_call, "name": "search", "args": {"query": "foo"}}],
+            ),
+            ToolMessage(content="search-results" * 20, tool_call_id=search_call),
+            AIMessage(
+                content="",
+                tool_calls=[{"id": calc_call, "name": "calculator", "args": {"a": 1, "b": 2}}],
+            ),
+            ToolMessage(content="42", tool_call_id=calc_call),
+        ]
+    )
+
+    middleware = ContextEditingMiddleware(
+        edits=[
+            ClearToolUsesEdit(
+                trigger=50,
+                clear_at_least=10,
+                keep=0,
+                exclude_tools=("search",),
+                placeholder="[cleared]",
+            )
+        ],
+    )
+
+    async def mock_handler(req: ModelRequest) -> AIMessage:
+        return AIMessage(content="mock response")
+
+    # Call awrap_model_call which modifies the request
+    await middleware.awrap_model_call(request, mock_handler)
+
+    search_tool = request.messages[1]
+    calc_tool = request.messages[3]
+
+    assert isinstance(search_tool, ToolMessage)
+    assert search_tool.content == "search-results" * 20
+
+    assert isinstance(calc_tool, ToolMessage)
+    assert calc_tool.content == "[cleared]"
@@ -44,7 +44,6 @@ from langchain.agents.middleware.model_call_limit import (
     ModelCallLimitExceededError,
 )
 from langchain.agents.middleware.model_fallback import ModelFallbackMiddleware
-from langchain.agents.middleware.prompt_caching import AnthropicPromptCachingMiddleware
 from langchain.agents.middleware.summarization import SummarizationMiddleware
 from langchain.agents.middleware.types import (
     AgentMiddleware,
@@ -1024,115 +1023,6 @@ def test_human_in_the_loop_middleware_description_as_callable() -> None:
     assert captured_request["action_requests"][1]["description"] == "Static description"
 
 
-# Tests for AnthropicPromptCachingMiddleware
-def test_anthropic_prompt_caching_middleware_initialization() -> None:
-    """Test AnthropicPromptCachingMiddleware initialization."""
-    # Test with custom values
-    middleware = AnthropicPromptCachingMiddleware(
-        type="ephemeral", ttl="1h", min_messages_to_cache=5
-    )
-    assert middleware.type == "ephemeral"
-    assert middleware.ttl == "1h"
-    assert middleware.min_messages_to_cache == 5
-
-    # Test with default values
-    middleware = AnthropicPromptCachingMiddleware()
-    assert middleware.type == "ephemeral"
-    assert middleware.ttl == "5m"
-    assert middleware.min_messages_to_cache == 0
-
-    fake_request = ModelRequest(
-        model=FakeToolCallingModel(),
-        messages=[HumanMessage("Hello")],
-        system_prompt=None,
-        tool_choice=None,
-        tools=[],
-        response_format=None,
-        state={"messages": [HumanMessage("Hello")]},
-        runtime=cast(Runtime, object()),
-        model_settings={},
-    )
-
-    def mock_handler(req: ModelRequest) -> AIMessage:
-        return AIMessage(content="mock response", **req.model_settings)
-
-    result = middleware.wrap_model_call(fake_request, mock_handler)
-    # Check that model_settings were passed through via the request
-    assert fake_request.model_settings == {"cache_control": {"type": "ephemeral", "ttl": "5m"}}
-
-
-def test_anthropic_prompt_caching_middleware_unsupported_model() -> None:
-    """Test AnthropicPromptCachingMiddleware with unsupported model."""
-    from typing import cast
-
-    fake_request = ModelRequest(
-        model=FakeToolCallingModel(),
-        messages=[HumanMessage("Hello")],
-        system_prompt=None,
-        tool_choice=None,
-        tools=[],
-        response_format=None,
-        state={"messages": [HumanMessage("Hello")]},
-        runtime=cast(Runtime, object()),
-        model_settings={},
-    )
-
-    middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="raise")
-
-    def mock_handler(req: ModelRequest) -> AIMessage:
-        return AIMessage(content="mock response")
-
-    with pytest.raises(
-        ValueError,
-        match="AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models. Please install langchain-anthropic.",
-    ):
-        middleware.wrap_model_call(fake_request, mock_handler)
-
-    langchain_anthropic = ModuleType("langchain_anthropic")
-
-    class MockChatAnthropic:
-        pass
-
-    langchain_anthropic.ChatAnthropic = MockChatAnthropic
-
-    with patch.dict("sys.modules", {"langchain_anthropic": langchain_anthropic}):
-        with pytest.raises(
-            ValueError,
-            match="AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models, not instances of",
-        ):
-            middleware.wrap_model_call(fake_request, mock_handler)
-
-    middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="warn")
-
-    with warnings.catch_warnings(record=True) as w:
-        result = middleware.wrap_model_call(fake_request, mock_handler)
-        assert len(w) == 1
-        assert (
-            "AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models. Please install langchain-anthropic."
-            in str(w[-1].message)
-        )
-        assert isinstance(result, AIMessage)
-
-    with warnings.catch_warnings(record=True) as w:
-        with patch.dict("sys.modules", {"langchain_anthropic": langchain_anthropic}):
-            result = middleware.wrap_model_call(fake_request, mock_handler)
-            assert isinstance(result, AIMessage)
-        assert len(w) == 1
-        assert (
-            "AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models, not instances of"
-            in str(w[-1].message)
-        )
-
-    middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore")
-
-    result = middleware.wrap_model_call(fake_request, mock_handler)
-    assert isinstance(result, AIMessage)
-
-    with patch.dict("sys.modules", {"langchain_anthropic": {"ChatAnthropic": object()}}):
-        result = middleware.wrap_model_call(fake_request, mock_handler)
-        assert isinstance(result, AIMessage)
-
-
 # Tests for SummarizationMiddleware
 def test_summarization_middleware_initialization() -> None:
     """Test SummarizationMiddleware initialization."""
@@ -0,0 +1,220 @@
"""Unit tests for ModelFallbackMiddleware."""

from __future__ import annotations

from typing import cast

import pytest
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage

from langchain.agents.middleware.model_fallback import ModelFallbackMiddleware
from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langgraph.runtime import Runtime


def _fake_runtime() -> Runtime:
    return cast(Runtime, object())


def _make_request() -> ModelRequest:
    """Create a minimal ModelRequest for testing."""
    model = GenericFakeChatModel(messages=iter([AIMessage(content="primary")]))
    return ModelRequest(
        model=model,
        system_prompt=None,
        messages=[],
        tool_choice=None,
        tools=[],
        response_format=None,
        state=cast("AgentState", {}),  # type: ignore[name-defined]
        runtime=_fake_runtime(),
        model_settings={},
    )


def test_primary_model_succeeds() -> None:
    """Test that primary model is used when it succeeds."""
    primary_model = GenericFakeChatModel(messages=iter([AIMessage(content="primary response")]))
    fallback_model = GenericFakeChatModel(messages=iter([AIMessage(content="fallback response")]))

    middleware = ModelFallbackMiddleware(fallback_model)
    request = _make_request()
    request.model = primary_model

    def mock_handler(req: ModelRequest) -> ModelResponse:
        # Simulate successful model call
        result = req.model.invoke([])
        return ModelResponse(result=[result])

    response = middleware.wrap_model_call(request, mock_handler)

    assert isinstance(response, ModelResponse)
    assert response.result[0].content == "primary response"


def test_fallback_on_primary_failure() -> None:
    """Test that fallback model is used when primary fails."""

    class FailingPrimaryModel(GenericFakeChatModel):
        def _generate(self, messages, **kwargs):
            raise ValueError("Primary model failed")

    primary_model = FailingPrimaryModel(messages=iter([AIMessage(content="should not see")]))
    fallback_model = GenericFakeChatModel(messages=iter([AIMessage(content="fallback response")]))

    middleware = ModelFallbackMiddleware(fallback_model)
    request = _make_request()
    request.model = primary_model

    def mock_handler(req: ModelRequest) -> ModelResponse:
        result = req.model.invoke([])
        return ModelResponse(result=[result])

    response = middleware.wrap_model_call(request, mock_handler)

    assert isinstance(response, ModelResponse)
    assert response.result[0].content == "fallback response"


def test_multiple_fallbacks() -> None:
    """Test that multiple fallback models are tried in sequence."""

    class FailingModel(GenericFakeChatModel):
        def _generate(self, messages, **kwargs):
            raise ValueError("Model failed")

    primary_model = FailingModel(messages=iter([AIMessage(content="should not see")]))
    fallback1 = FailingModel(messages=iter([AIMessage(content="fallback1")]))
    fallback2 = GenericFakeChatModel(messages=iter([AIMessage(content="fallback2")]))

    middleware = ModelFallbackMiddleware(fallback1, fallback2)
    request = _make_request()
    request.model = primary_model

    def mock_handler(req: ModelRequest) -> ModelResponse:
        result = req.model.invoke([])
        return ModelResponse(result=[result])

    response = middleware.wrap_model_call(request, mock_handler)

    assert isinstance(response, ModelResponse)
    assert response.result[0].content == "fallback2"


def test_all_models_fail() -> None:
    """Test that exception is raised when all models fail."""

    class AlwaysFailingModel(GenericFakeChatModel):
        def _generate(self, messages, **kwargs):
            raise ValueError("Model failed")

    primary_model = AlwaysFailingModel(messages=iter([]))
    fallback_model = AlwaysFailingModel(messages=iter([]))

    middleware = ModelFallbackMiddleware(fallback_model)
    request = _make_request()
    request.model = primary_model

    def mock_handler(req: ModelRequest) -> ModelResponse:
        result = req.model.invoke([])
        return ModelResponse(result=[result])

    with pytest.raises(ValueError, match="Model failed"):
        middleware.wrap_model_call(request, mock_handler)


# ==============================================================================
# Async Tests
# ==============================================================================


async def test_primary_model_succeeds_async() -> None:
    """Test async version - primary model is used when it succeeds."""
    primary_model = GenericFakeChatModel(messages=iter([AIMessage(content="primary response")]))
    fallback_model = GenericFakeChatModel(messages=iter([AIMessage(content="fallback response")]))

    middleware = ModelFallbackMiddleware(fallback_model)
    request = _make_request()
    request.model = primary_model

    async def mock_handler(req: ModelRequest) -> ModelResponse:
        # Simulate successful async model call
        result = await req.model.ainvoke([])
        return ModelResponse(result=[result])

    response = await middleware.awrap_model_call(request, mock_handler)

    assert isinstance(response, ModelResponse)
    assert response.result[0].content == "primary response"


async def test_fallback_on_primary_failure_async() -> None:
    """Test async version - fallback model is used when primary fails."""

    class AsyncFailingPrimaryModel(GenericFakeChatModel):
        async def _agenerate(self, messages, **kwargs):
            raise ValueError("Primary model failed")

    primary_model = AsyncFailingPrimaryModel(messages=iter([AIMessage(content="should not see")]))
    fallback_model = GenericFakeChatModel(messages=iter([AIMessage(content="fallback response")]))

    middleware = ModelFallbackMiddleware(fallback_model)
    request = _make_request()
    request.model = primary_model

    async def mock_handler(req: ModelRequest) -> ModelResponse:
        result = await req.model.ainvoke([])
        return ModelResponse(result=[result])

    response = await middleware.awrap_model_call(request, mock_handler)

    assert isinstance(response, ModelResponse)
    assert response.result[0].content == "fallback response"


async def test_multiple_fallbacks_async() -> None:
    """Test async version - multiple fallback models are tried in sequence."""

    class AsyncFailingModel(GenericFakeChatModel):
        async def _agenerate(self, messages, **kwargs):
            raise ValueError("Model failed")

    primary_model = AsyncFailingModel(messages=iter([AIMessage(content="should not see")]))
    fallback1 = AsyncFailingModel(messages=iter([AIMessage(content="fallback1")]))
    fallback2 = GenericFakeChatModel(messages=iter([AIMessage(content="fallback2")]))

    middleware = ModelFallbackMiddleware(fallback1, fallback2)
    request = _make_request()
    request.model = primary_model

    async def mock_handler(req: ModelRequest) -> ModelResponse:
        result = await req.model.ainvoke([])
        return ModelResponse(result=[result])

    response = await middleware.awrap_model_call(request, mock_handler)

    assert isinstance(response, ModelResponse)
    assert response.result[0].content == "fallback2"


async def test_all_models_fail_async() -> None:
    """Test async version - exception is raised when all models fail."""

    class AsyncAlwaysFailingModel(GenericFakeChatModel):
        async def _agenerate(self, messages, **kwargs):
            raise ValueError("Model failed")

    primary_model = AsyncAlwaysFailingModel(messages=iter([]))
    fallback_model = AsyncAlwaysFailingModel(messages=iter([]))

    middleware = ModelFallbackMiddleware(fallback_model)
    request = _make_request()
    request.model = primary_model

    async def mock_handler(req: ModelRequest) -> ModelResponse:
        result = await req.model.ainvoke([])
        return ModelResponse(result=[result])

    with pytest.raises(ValueError, match="Model failed"):
        await middleware.awrap_model_call(request, mock_handler)
@@ -0,0 +1,172 @@
"""Unit tests for PlanningMiddleware."""

from __future__ import annotations

from typing import cast

from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage

from langchain.agents.middleware.planning import PlanningMiddleware
from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langgraph.runtime import Runtime


def _fake_runtime() -> Runtime:
    return cast(Runtime, object())


def _make_request(system_prompt: str | None = None) -> ModelRequest:
    """Create a minimal ModelRequest for testing."""
    model = GenericFakeChatModel(messages=iter([AIMessage(content="response")]))
    return ModelRequest(
        model=model,
        system_prompt=system_prompt,
        messages=[],
        tool_choice=None,
        tools=[],
        response_format=None,
        state=cast("AgentState", {}),  # type: ignore[name-defined]
        runtime=_fake_runtime(),
        model_settings={},
    )


def test_adds_system_prompt_when_none_exists() -> None:
    """Test that middleware adds system prompt when request has none."""
    middleware = PlanningMiddleware()
    request = _make_request(system_prompt=None)

    def mock_handler(req: ModelRequest) -> ModelResponse:
        return ModelResponse(result=[AIMessage(content="response")])

    middleware.wrap_model_call(request, mock_handler)

    # System prompt should be set
    assert request.system_prompt is not None
    assert "write_todos" in request.system_prompt


def test_appends_to_existing_system_prompt() -> None:
    """Test that middleware appends to existing system prompt."""
    existing_prompt = "You are a helpful assistant."
    middleware = PlanningMiddleware()
    request = _make_request(system_prompt=existing_prompt)

    def mock_handler(req: ModelRequest) -> ModelResponse:
        return ModelResponse(result=[AIMessage(content="response")])

    middleware.wrap_model_call(request, mock_handler)

    # System prompt should contain both
    assert request.system_prompt is not None
    assert existing_prompt in request.system_prompt
    assert "write_todos" in request.system_prompt
    assert request.system_prompt.startswith(existing_prompt)


def test_custom_system_prompt() -> None:
    """Test that middleware uses custom system prompt."""
    custom_prompt = "Custom planning instructions"
    middleware = PlanningMiddleware(system_prompt=custom_prompt)
    request = _make_request(system_prompt=None)

    def mock_handler(req: ModelRequest) -> ModelResponse:
        return ModelResponse(result=[AIMessage(content="response")])

    middleware.wrap_model_call(request, mock_handler)

    # Should use custom prompt
    assert request.system_prompt == custom_prompt


def test_has_write_todos_tool() -> None:
    """Test that middleware registers the write_todos tool."""
    middleware = PlanningMiddleware()

    # Should have one tool registered
    assert len(middleware.tools) == 1
    assert middleware.tools[0].name == "write_todos"


def test_custom_tool_description() -> None:
    """Test that middleware uses custom tool description."""
    custom_description = "Custom todo tool description"
    middleware = PlanningMiddleware(tool_description=custom_description)

    # Tool should use custom description
    assert len(middleware.tools) == 1
    assert middleware.tools[0].description == custom_description


# ==============================================================================
# Async Tests
# ==============================================================================


async def test_adds_system_prompt_when_none_exists_async() -> None:
    """Test async version - middleware adds system prompt when request has none."""
    middleware = PlanningMiddleware()
    request = _make_request(system_prompt=None)

    async def mock_handler(req: ModelRequest) -> ModelResponse:
        return ModelResponse(result=[AIMessage(content="response")])

    await middleware.awrap_model_call(request, mock_handler)

    # System prompt should be set
    assert request.system_prompt is not None
    assert "write_todos" in request.system_prompt


async def test_appends_to_existing_system_prompt_async() -> None:
    """Test async version - middleware appends to existing system prompt."""
    existing_prompt = "You are a helpful assistant."
    middleware = PlanningMiddleware()
    request = _make_request(system_prompt=existing_prompt)

    async def mock_handler(req: ModelRequest) -> ModelResponse:
        return ModelResponse(result=[AIMessage(content="response")])

    await middleware.awrap_model_call(request, mock_handler)

    # System prompt should contain both
    assert request.system_prompt is not None
    assert existing_prompt in request.system_prompt
    assert "write_todos" in request.system_prompt
    assert request.system_prompt.startswith(existing_prompt)


async def test_custom_system_prompt_async() -> None:
    """Test async version - middleware uses custom system prompt."""
    custom_prompt = "Custom planning instructions"
    middleware = PlanningMiddleware(system_prompt=custom_prompt)
    request = _make_request(system_prompt=None)

    async def mock_handler(req: ModelRequest) -> ModelResponse:
        return ModelResponse(result=[AIMessage(content="response")])

    await middleware.awrap_model_call(request, mock_handler)

    # Should use custom prompt
    assert request.system_prompt == custom_prompt


async def test_handler_called_with_modified_request_async() -> None:
    """Test async version - handler receives the modified request."""
    middleware = PlanningMiddleware()
    request = _make_request(system_prompt="Original")
    handler_called = {"value": False}
    received_prompt = {"value": None}

    async def mock_handler(req: ModelRequest) -> ModelResponse:
        handler_called["value"] = True
        received_prompt["value"] = req.system_prompt
        return ModelResponse(result=[AIMessage(content="response")])

    await middleware.awrap_model_call(request, mock_handler)

    assert handler_called["value"]
    assert received_prompt["value"] is not None
    assert "Original" in received_prompt["value"]
    assert "write_todos" in received_prompt["value"]
@@ -0,0 +1,9 @@
"""Middleware for Anthropic models."""

from langchain_anthropic.middleware.prompt_caching import (
    AnthropicPromptCachingMiddleware,
)

__all__ = [
    "AnthropicPromptCachingMiddleware",
]
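With this new package in place, the middleware is imported from `langchain_anthropic.middleware` rather than `langchain.agents.middleware`. Below is a hedged usage sketch: only the `langchain_anthropic.middleware` import path is taken directly from the diff; the `create_agent` entry point, its `middleware` parameter, and the model name are assumptions about the langchain v1 agent API and may need adjusting for your version.

```python
# Hedged usage sketch of the relocated middleware; see the caveats above.
from langchain.agents import create_agent  # assumed v1 entry point
from langchain_anthropic import ChatAnthropic
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware

agent = create_agent(
    model=ChatAnthropic(model="<your-claude-model>"),  # placeholder model name
    tools=[],
    middleware=[AnthropicPromptCachingMiddleware(ttl="1h", min_messages_to_cache=3)],
)
```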
@@ -0,0 +1,157 @@
"""Anthropic prompt caching middleware.

Requires:
    - langchain: For agent middleware framework
    - langchain-anthropic: For ChatAnthropic model (already a dependency)
"""

from collections.abc import Awaitable, Callable
from typing import Literal
from warnings import warn

try:
    from langchain.agents.middleware.types import (
        AgentMiddleware,
        ModelCallResult,
        ModelRequest,
        ModelResponse,
    )
except ImportError as e:
    msg = (
        "AnthropicPromptCachingMiddleware requires 'langchain' to be installed. "
        "This middleware is designed for use with LangChain agents. "
        "Install it with: pip install langchain"
    )
    raise ImportError(msg) from e


class AnthropicPromptCachingMiddleware(AgentMiddleware):
    """Prompt Caching Middleware.

    Optimizes API usage by caching conversation prefixes for Anthropic models.

    Requires both 'langchain' and 'langchain-anthropic' packages to be installed.

    Learn more about Anthropic prompt caching
    [here](https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching).
    """

    def __init__(
        self,
        type: Literal["ephemeral"] = "ephemeral",  # noqa: A002
        ttl: Literal["5m", "1h"] = "5m",
        min_messages_to_cache: int = 0,
        unsupported_model_behavior: Literal["ignore", "warn", "raise"] = "warn",
    ) -> None:
        """Initialize the middleware with cache control settings.

        Args:
            type: The type of cache to use, only "ephemeral" is supported.
            ttl: The time to live for the cache, only "5m" and "1h" are
                supported.
            min_messages_to_cache: The minimum number of messages until the
                cache is used, default is 0.
            unsupported_model_behavior: The behavior to take when an
                unsupported model is used. "ignore" will ignore the unsupported
                model and continue without caching. "warn" will warn the user
                and continue without caching. "raise" will raise an error and
                stop the agent.
        """
        self.type = type
        self.ttl = ttl
        self.min_messages_to_cache = min_messages_to_cache
        self.unsupported_model_behavior = unsupported_model_behavior

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelCallResult:
        """Modify the model request to add cache control blocks."""
        try:
            from langchain_anthropic import ChatAnthropic

            chat_anthropic_cls: type | None = ChatAnthropic
        except ImportError:
            chat_anthropic_cls = None

        msg: str | None = None

        if chat_anthropic_cls is None:
            msg = (
                "AnthropicPromptCachingMiddleware caching middleware only supports "
                "Anthropic models. "
                "Please install langchain-anthropic."
            )
        elif not isinstance(request.model, chat_anthropic_cls):
            msg = (
                "AnthropicPromptCachingMiddleware caching middleware only supports "
                f"Anthropic models, not instances of {type(request.model)}"
            )

        if msg is not None:
            if self.unsupported_model_behavior == "raise":
                raise ValueError(msg)
            if self.unsupported_model_behavior == "warn":
                warn(msg, stacklevel=3)
            else:
                return handler(request)

        messages_count = (
            len(request.messages) + 1
            if request.system_prompt
            else len(request.messages)
        )
        if messages_count < self.min_messages_to_cache:
            return handler(request)

        request.model_settings["cache_control"] = {"type": self.type, "ttl": self.ttl}

        return handler(request)

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelCallResult:
        """Modify the model request to add cache control blocks (async version)."""
        try:
            from langchain_anthropic import ChatAnthropic

            chat_anthropic_cls: type | None = ChatAnthropic
        except ImportError:
            chat_anthropic_cls = None

        msg: str | None = None

        if chat_anthropic_cls is None:
            msg = (
                "AnthropicPromptCachingMiddleware caching middleware only supports "
                "Anthropic models. "
                "Please install langchain-anthropic."
            )
        elif not isinstance(request.model, chat_anthropic_cls):
            msg = (
                "AnthropicPromptCachingMiddleware caching middleware only supports "
                f"Anthropic models, not instances of {type(request.model)}"
            )

        if msg is not None:
            if self.unsupported_model_behavior == "raise":
                raise ValueError(msg)
            if self.unsupported_model_behavior == "warn":
                warn(msg, stacklevel=3)
            else:
                return await handler(request)

        messages_count = (
            len(request.messages) + 1
            if request.system_prompt
            else len(request.messages)
        )
        if messages_count < self.min_messages_to_cache:
            return await handler(request)

        request.model_settings["cache_control"] = {"type": self.type, "ttl": self.ttl}

        return await handler(request)
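One detail worth making concrete from both hooks above: the system prompt counts as one extra message, and cache control is skipped while the total stays below `min_messages_to_cache`. Here is a toy restatement of just that check, plain Python rather than the middleware itself; `messages` and `system_prompt` are placeholders.

```python
# Toy restatement of the threshold check used by wrap_model_call/awrap_model_call.
def should_cache(
    messages: list[str], system_prompt: str | None, min_messages_to_cache: int
) -> bool:
    messages_count = len(messages) + 1 if system_prompt else len(messages)
    return messages_count >= min_messages_to_cache


print(should_cache(["hi"] * 3, None, 5))            # False -> handler called without cache_control
print(should_cache(["hi"] * 5, "system rules", 5))  # True  -> cache_control is attached
```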
@@ -41,6 +41,7 @@ test = [
     "vcrpy>=7.0.0,<8.0.0",
     "langchain-core",
     "langchain-tests",
+    "langchain",
 ]
 lint = ["ruff>=0.13.1,<0.14.0"]
 dev = ["langchain-core"]
@@ -55,6 +56,7 @@ typing = [
 [tool.uv.sources]
 langchain-core = { path = "../../core", editable = true }
 langchain-tests = { path = "../../standard-tests", editable = true }
+langchain = { path = "../../langchain_v1", editable = true }
 
 [tool.mypy]
 disallow_untyped_defs = "True"
@@ -0,0 +1 @@
"""Tests for Anthropic middleware."""
@@ -0,0 +1,238 @@
"""Tests for Anthropic prompt caching middleware."""

import warnings
from typing import Any, cast

import pytest
from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langgraph.runtime import Runtime

from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware


class FakeToolCallingModel(BaseChatModel):
    """Fake model for testing middleware."""

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Top Level call"""
        messages_string = "-".join([str(m.content) for m in messages])
        message = AIMessage(content=messages_string, id="0")
        return ChatResult(generations=[ChatGeneration(message=message)])

    async def _agenerate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: AsyncCallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Async top level call"""
        messages_string = "-".join([str(m.content) for m in messages])
        message = AIMessage(content=messages_string, id="0")
        return ChatResult(generations=[ChatGeneration(message=message)])

    @property
    def _llm_type(self) -> str:
        return "fake-tool-call-model"


def test_anthropic_prompt_caching_middleware_initialization() -> None:
    """Test AnthropicPromptCachingMiddleware initialization."""
    # Test with custom values
    middleware = AnthropicPromptCachingMiddleware(
        type="ephemeral", ttl="1h", min_messages_to_cache=5
    )
    assert middleware.type == "ephemeral"
    assert middleware.ttl == "1h"
    assert middleware.min_messages_to_cache == 5

    # Test with default values
    middleware = AnthropicPromptCachingMiddleware()
    assert middleware.type == "ephemeral"
    assert middleware.ttl == "5m"
    assert middleware.min_messages_to_cache == 0

    fake_request = ModelRequest(
        model=FakeToolCallingModel(),
        messages=[HumanMessage("Hello")],
        system_prompt=None,
        tool_choice=None,
        tools=[],
        response_format=None,
        state={"messages": [HumanMessage("Hello")]},
        runtime=cast(Runtime, object()),
        model_settings={},
    )

    def mock_handler(req: ModelRequest) -> ModelResponse:
        return ModelResponse(result=[AIMessage(content="mock response")])

    middleware.wrap_model_call(fake_request, mock_handler)
    # Check that model_settings were passed through via the request
    assert fake_request.model_settings == {
        "cache_control": {"type": "ephemeral", "ttl": "5m"}
    }


def test_anthropic_prompt_caching_middleware_unsupported_model() -> None:
    """Test AnthropicPromptCachingMiddleware with unsupported model."""
    fake_request = ModelRequest(
        model=FakeToolCallingModel(),
        messages=[HumanMessage("Hello")],
        system_prompt=None,
        tool_choice=None,
        tools=[],
        response_format=None,
        state={"messages": [HumanMessage("Hello")]},
        runtime=cast(Runtime, object()),
        model_settings={},
    )

    middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="raise")

    def mock_handler(req: ModelRequest) -> ModelResponse:
        return ModelResponse(result=[AIMessage(content="mock response")])

    # Since we're in the langchain-anthropic package, ChatAnthropic is always
    # available. Test that it raises an error for unsupported model instances
    with pytest.raises(
        ValueError,
        match=(
            "AnthropicPromptCachingMiddleware caching middleware only supports "
            "Anthropic models, not instances of"
        ),
    ):
        middleware.wrap_model_call(fake_request, mock_handler)

    middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="warn")

    # Test warn behavior for unsupported model instances
    with warnings.catch_warnings(record=True) as w:
        result = middleware.wrap_model_call(fake_request, mock_handler)
        assert isinstance(result, ModelResponse)
        assert len(w) == 1
        assert (
            "AnthropicPromptCachingMiddleware caching middleware only supports "
            "Anthropic models, not instances of"
        ) in str(w[-1].message)

    # Test ignore behavior
    middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore")
    result = middleware.wrap_model_call(fake_request, mock_handler)
    assert isinstance(result, ModelResponse)


async def test_anthropic_prompt_caching_middleware_async() -> None:
    """Test AnthropicPromptCachingMiddleware async path."""
    # Test with custom values
    middleware = AnthropicPromptCachingMiddleware(
        type="ephemeral", ttl="1h", min_messages_to_cache=5
    )

    fake_request = ModelRequest(
        model=FakeToolCallingModel(),
        messages=[HumanMessage("Hello")] * 6,
        system_prompt=None,
        tool_choice=None,
        tools=[],
        response_format=None,
        state={"messages": [HumanMessage("Hello")] * 6},
        runtime=cast(Runtime, object()),
        model_settings={},
    )

    async def mock_handler(req: ModelRequest) -> ModelResponse:
        return ModelResponse(result=[AIMessage(content="mock response")])

    result = await middleware.awrap_model_call(fake_request, mock_handler)
    assert isinstance(result, ModelResponse)
    # Check that model_settings were passed through via the request
    assert fake_request.model_settings == {
        "cache_control": {"type": "ephemeral", "ttl": "1h"}
    }


async def test_anthropic_prompt_caching_middleware_async_unsupported_model() -> None:
    """Test AnthropicPromptCachingMiddleware async path with unsupported model."""
    fake_request = ModelRequest(
        model=FakeToolCallingModel(),
        messages=[HumanMessage("Hello")],
        system_prompt=None,
        tool_choice=None,
        tools=[],
        response_format=None,
        state={"messages": [HumanMessage("Hello")]},
        runtime=cast(Runtime, object()),
        model_settings={},
    )

    middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="raise")

    async def mock_handler(req: ModelRequest) -> ModelResponse:
        return ModelResponse(result=[AIMessage(content="mock response")])

    # Test that it raises an error for unsupported model instances
    with pytest.raises(
        ValueError,
        match=(
            "AnthropicPromptCachingMiddleware caching middleware only supports "
            "Anthropic models, not instances of"
        ),
    ):
        await middleware.awrap_model_call(fake_request, mock_handler)

    middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="warn")

    # Test warn behavior for unsupported model instances
    with warnings.catch_warnings(record=True) as w:
        result = await middleware.awrap_model_call(fake_request, mock_handler)
        assert isinstance(result, ModelResponse)
        assert len(w) == 1
        assert (
            "AnthropicPromptCachingMiddleware caching middleware only supports "
            "Anthropic models, not instances of"
        ) in str(w[-1].message)

    # Test ignore behavior
    middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore")
    result = await middleware.awrap_model_call(fake_request, mock_handler)
    assert isinstance(result, ModelResponse)


async def test_anthropic_prompt_caching_middleware_async_min_messages() -> None:
    """Test async path respects min_messages_to_cache."""
    middleware = AnthropicPromptCachingMiddleware(min_messages_to_cache=5)

    # Test with fewer messages than minimum
    fake_request = ModelRequest(
        model=FakeToolCallingModel(),
        messages=[HumanMessage("Hello")] * 3,
        system_prompt=None,
        tool_choice=None,
        tools=[],
        response_format=None,
        state={"messages": [HumanMessage("Hello")] * 3},
        runtime=cast(Runtime, object()),
        model_settings={},
    )

    async def mock_handler(req: ModelRequest) -> ModelResponse:
        return ModelResponse(result=[AIMessage(content="mock response")])

    result = await middleware.awrap_model_call(fake_request, mock_handler)
    assert isinstance(result, ModelResponse)
    # Cache control should NOT be added when message count is below minimum
    assert fake_request.model_settings == {}
libs/partners/anthropic/uv.lock (generated, 1382 lines changed): file diff suppressed because it is too large.