This commit is contained in:
Eugene Yurtsev
2025-10-13 14:34:36 -04:00
parent 89e3a10cbd
commit 23f5b0cedf

View File

@@ -4,7 +4,7 @@ import warnings
from typing import Any, cast
import pytest
from langchain.agents.middleware.types import ModelRequest
from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
@@ -25,7 +25,7 @@ class FakeToolCallingModel(BaseChatModel):
**kwargs: Any,
) -> ChatResult:
"""Top Level call"""
messages_string = "-".join([m.content for m in messages])
messages_string = "-".join([str(m.content) for m in messages])
message = AIMessage(content=messages_string, id="0")
return ChatResult(generations=[ChatGeneration(message=message)])
@@ -62,8 +62,10 @@ def test_anthropic_prompt_caching_middleware_initialization() -> None:
model_settings={},
)
def mock_handler(req: ModelRequest) -> AIMessage:
return AIMessage(content="mock response", **req.model_settings)
def mock_handler(req: ModelRequest) -> ModelResponse:
return ModelResponse(
result=[AIMessage(content="mock response")]
)
middleware.wrap_model_call(fake_request, mock_handler)
# Check that model_settings were passed through via the request
@@ -88,8 +90,8 @@ def test_anthropic_prompt_caching_middleware_unsupported_model() -> None:
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="raise")
def mock_handler(req: ModelRequest) -> AIMessage:
return AIMessage(content="mock response")
def mock_handler(req: ModelRequest) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="mock response")])
# Since we're in the langchain-anthropic package, ChatAnthropic is always
# available. Test that it raises an error for unsupported model instances
@@ -107,7 +109,7 @@ def test_anthropic_prompt_caching_middleware_unsupported_model() -> None:
# Test warn behavior for unsupported model instances
with warnings.catch_warnings(record=True) as w:
result = middleware.wrap_model_call(fake_request, mock_handler)
assert isinstance(result, AIMessage)
assert isinstance(result, ModelResponse)
assert len(w) == 1
assert (
"AnthropicPromptCachingMiddleware caching middleware only supports "
@@ -117,4 +119,4 @@ def test_anthropic_prompt_caching_middleware_unsupported_model() -> None:
# Test ignore behavior
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore")
result = middleware.wrap_model_call(fake_request, mock_handler)
assert isinstance(result, AIMessage)
assert isinstance(result, ModelResponse)