diff --git a/libs/partners/anthropic/tests/unit_tests/middleware/test_prompt_caching.py b/libs/partners/anthropic/tests/unit_tests/middleware/test_prompt_caching.py index f2b2c5cfb2b..789a2eb5192 100644 --- a/libs/partners/anthropic/tests/unit_tests/middleware/test_prompt_caching.py +++ b/libs/partners/anthropic/tests/unit_tests/middleware/test_prompt_caching.py @@ -4,7 +4,7 @@ import warnings from typing import Any, cast import pytest -from langchain.agents.middleware.types import ModelRequest +from langchain.agents.middleware.types import ModelRequest, ModelResponse from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models import BaseChatModel from langchain_core.messages import AIMessage, BaseMessage, HumanMessage @@ -25,7 +25,7 @@ class FakeToolCallingModel(BaseChatModel): **kwargs: Any, ) -> ChatResult: """Top Level call""" - messages_string = "-".join([m.content for m in messages]) + messages_string = "-".join([str(m.content) for m in messages]) message = AIMessage(content=messages_string, id="0") return ChatResult(generations=[ChatGeneration(message=message)]) @@ -62,8 +62,10 @@ def test_anthropic_prompt_caching_middleware_initialization() -> None: model_settings={}, ) - def mock_handler(req: ModelRequest) -> AIMessage: - return AIMessage(content="mock response", **req.model_settings) + def mock_handler(req: ModelRequest) -> ModelResponse: + return ModelResponse( + result=[AIMessage(content="mock response")] + ) middleware.wrap_model_call(fake_request, mock_handler) # Check that model_settings were passed through via the request @@ -88,8 +90,8 @@ def test_anthropic_prompt_caching_middleware_unsupported_model() -> None: middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="raise") - def mock_handler(req: ModelRequest) -> AIMessage: - return AIMessage(content="mock response") + def mock_handler(req: ModelRequest) -> ModelResponse: + return ModelResponse(result=[AIMessage(content="mock response")]) # Since we're in the langchain-anthropic package, ChatAnthropic is always # available. Test that it raises an error for unsupported model instances @@ -107,7 +109,7 @@ def test_anthropic_prompt_caching_middleware_unsupported_model() -> None: # Test warn behavior for unsupported model instances with warnings.catch_warnings(record=True) as w: result = middleware.wrap_model_call(fake_request, mock_handler) - assert isinstance(result, AIMessage) + assert isinstance(result, ModelResponse) assert len(w) == 1 assert ( "AnthropicPromptCachingMiddleware caching middleware only supports " @@ -117,4 +119,4 @@ def test_anthropic_prompt_caching_middleware_unsupported_model() -> None: # Test ignore behavior middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore") result = middleware.wrap_model_call(fake_request, mock_handler) - assert isinstance(result, AIMessage) + assert isinstance(result, ModelResponse)