diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py index 08c3ff66ce2..93ec406fc5e 100644 --- a/libs/partners/anthropic/langchain_anthropic/chat_models.py +++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py @@ -41,7 +41,7 @@ from langchain_core.messages import ( ToolCall, ToolMessage, ) -from langchain_core.messages.ai import UsageMetadata +from langchain_core.messages.ai import InputTokenDetails, UsageMetadata from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk from langchain_core.output_parsers import ( JsonOutputKeyToolsParser, @@ -766,12 +766,7 @@ class ChatAnthropic(BaseChatModel): ) else: msg = AIMessage(content=content) - # Collect token usage - msg.usage_metadata = { - "input_tokens": data.usage.input_tokens, - "output_tokens": data.usage.output_tokens, - "total_tokens": data.usage.input_tokens + data.usage.output_tokens, - } + msg.usage_metadata = _create_usage_metadata(data.usage) return ChatResult( generations=[ChatGeneration(message=msg)], llm_output=llm_output, @@ -1182,14 +1177,10 @@ def _make_message_chunk_from_anthropic_event( message_chunk: Optional[AIMessageChunk] = None # See https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/lib/streaming/_messages.py # noqa: E501 if event.type == "message_start" and stream_usage: - input_tokens = event.message.usage.input_tokens + usage_metadata = _create_usage_metadata(event.message.usage) message_chunk = AIMessageChunk( content="" if coerce_content_to_string else [], - usage_metadata=UsageMetadata( - input_tokens=input_tokens, - output_tokens=0, - total_tokens=input_tokens, - ), + usage_metadata=usage_metadata, ) elif ( event.type == "content_block_start" @@ -1235,14 +1226,10 @@ def _make_message_chunk_from_anthropic_event( tool_call_chunks=[tool_call_chunk], # type: ignore ) elif event.type == "message_delta" and stream_usage: - output_tokens = event.usage.output_tokens + usage_metadata = _create_usage_metadata(event.usage) message_chunk = AIMessageChunk( content="", - usage_metadata=UsageMetadata( - input_tokens=0, - output_tokens=output_tokens, - total_tokens=output_tokens, - ), + usage_metadata=usage_metadata, response_metadata={ "stop_reason": event.delta.stop_reason, "stop_sequence": event.delta.stop_sequence, @@ -1257,3 +1244,21 @@ def _make_message_chunk_from_anthropic_event( @deprecated(since="0.1.0", removal="0.3.0", alternative="ChatAnthropic") class ChatAnthropicMessages(ChatAnthropic): pass + + +def _create_usage_metadata(anthropic_usage: BaseModel) -> UsageMetadata: + input_token_details: Dict = { + "cache_read": getattr(anthropic_usage, "cache_read_input_tokens", None), + "cache_creation": getattr(anthropic_usage, "cache_creation_input_tokens", None), + } + + input_tokens = getattr(anthropic_usage, "input_tokens", 0) + output_tokens = getattr(anthropic_usage, "output_tokens", 0) + return UsageMetadata( + input_tokens=input_tokens, + output_tokens=output_tokens, + total_tokens=input_tokens + output_tokens, + input_token_details=InputTokenDetails( + **{k: v for k, v in input_token_details.items() if v is not None} + ), + ) diff --git a/libs/partners/anthropic/tests/integration_tests/test_standard.py b/libs/partners/anthropic/tests/integration_tests/test_standard.py index e439fffc74f..24f8766ad06 100644 --- a/libs/partners/anthropic/tests/integration_tests/test_standard.py +++ b/libs/partners/anthropic/tests/integration_tests/test_standard.py @@ -1,12 +1,16 @@ """Standard LangChain interface tests""" -from typing import Type +from pathlib import Path +from typing import List, Literal, Type, cast from langchain_core.language_models import BaseChatModel +from langchain_core.messages import AIMessage from langchain_standard_tests.integration_tests import ChatModelIntegrationTests from langchain_anthropic import ChatAnthropic +REPO_ROOT_DIR = Path(__file__).parents[5] + class TestAnthropicStandard(ChatModelIntegrationTests): @property @@ -28,3 +32,103 @@ class TestAnthropicStandard(ChatModelIntegrationTests): @property def supports_anthropic_inputs(self) -> bool: return True + + @property + def supported_usage_metadata_details( + self, + ) -> List[ + Literal[ + "audio_input", + "audio_output", + "reasoning_output", + "cache_read_input", + "cache_creation_input", + ] + ]: + return ["cache_read_input", "cache_creation_input"] + + def invoke_with_cache_creation_input(self, *, stream: bool = False) -> AIMessage: + llm = ChatAnthropic( + model="claude-3-5-sonnet-20240620", # type: ignore[call-arg] + extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"}, # type: ignore[call-arg] + ) + with open(REPO_ROOT_DIR / "README.md", "r") as f: + readme = f.read() + + input_ = f"""What's langchain? Here's the langchain README: + + {readme} + """ + return _invoke( + llm, + [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": input_, + "cache_control": {"type": "ephemeral"}, + } + ], + } + ], + stream, + ) + + def invoke_with_cache_read_input(self, *, stream: bool = False) -> AIMessage: + llm = ChatAnthropic( + model="claude-3-5-sonnet-20240620", # type: ignore[call-arg] + extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"}, # type: ignore[call-arg] + ) + with open(REPO_ROOT_DIR / "README.md", "r") as f: + readme = f.read() + + input_ = f"""What's langchain? Here's the langchain README: + + {readme} + """ + + # invoke twice so first invocation is cached + _invoke( + llm, + [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": input_, + "cache_control": {"type": "ephemeral"}, + } + ], + } + ], + stream, + ) + return _invoke( + llm, + [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": input_, + "cache_control": {"type": "ephemeral"}, + } + ], + } + ], + stream, + ) + + +def _invoke(llm: ChatAnthropic, input_: list, stream: bool) -> AIMessage: + if stream: + full = None + for chunk in llm.stream(input_): + full = full + chunk if full else chunk # type: ignore[operator] + return cast(AIMessage, full) + else: + return cast(AIMessage, llm.invoke(input_)) diff --git a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py index 781a3b6e747..8c9f908c19f 100644 --- a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py +++ b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py @@ -5,8 +5,11 @@ from typing import Any, Callable, Dict, Literal, Type, cast import pytest from anthropic.types import Message, TextBlock, Usage +from anthropic.types.beta.prompt_caching import ( + PromptCachingBetaMessage, + PromptCachingBetaUsage, +) from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage -from langchain_core.outputs import ChatGeneration, ChatResult from langchain_core.runnables import RunnableBinding from langchain_core.tools import BaseTool from pydantic import BaseModel, Field, SecretStr @@ -89,30 +92,49 @@ def test__format_output() -> None: usage=Usage(input_tokens=2, output_tokens=1), type="message", ) - expected = ChatResult( - generations=[ - ChatGeneration( - message=AIMessage( # type: ignore[misc] - "bar", - usage_metadata={ - "input_tokens": 2, - "output_tokens": 1, - "total_tokens": 3, - }, - ) - ), - ], - llm_output={ - "id": "foo", - "model": "baz", - "stop_reason": None, - "stop_sequence": None, - "usage": {"input_tokens": 2, "output_tokens": 1}, + expected = AIMessage( # type: ignore[misc] + "bar", + usage_metadata={ + "input_tokens": 2, + "output_tokens": 1, + "total_tokens": 3, + "input_token_details": {}, }, ) llm = ChatAnthropic(model="test", anthropic_api_key="test") # type: ignore[call-arg, call-arg] actual = llm._format_output(anthropic_msg) - assert expected == actual + assert actual.generations[0].message == expected + + +def test__format_output_cached() -> None: + anthropic_msg = PromptCachingBetaMessage( + id="foo", + content=[TextBlock(type="text", text="bar")], + model="baz", + role="assistant", + stop_reason=None, + stop_sequence=None, + usage=PromptCachingBetaUsage( + input_tokens=2, + output_tokens=1, + cache_creation_input_tokens=3, + cache_read_input_tokens=4, + ), + type="message", + ) + expected = AIMessage( # type: ignore[misc] + "bar", + usage_metadata={ + "input_tokens": 2, + "output_tokens": 1, + "total_tokens": 3, + "input_token_details": {"cache_creation": 3, "cache_read": 4}, + }, + ) + + llm = ChatAnthropic(model="test", anthropic_api_key="test") # type: ignore[call-arg, call-arg] + actual = llm._format_output(anthropic_msg) + assert actual.generations[0].message == expected def test__merge_messages() -> None: