anthropic[patch]: add usage_metadata details (#27087)

fixes https://github.com/langchain-ai/langchain/pull/27087
This commit is contained in:
Bagatur
2024-10-04 08:46:49 -07:00
committed by GitHub
parent e8e5d67a8d
commit 0495b7f441
3 changed files with 172 additions and 41 deletions

View File

@@ -41,7 +41,7 @@ from langchain_core.messages import (
ToolCall,
ToolMessage,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.messages.ai import InputTokenDetails, UsageMetadata
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
from langchain_core.output_parsers import (
JsonOutputKeyToolsParser,
@@ -766,12 +766,7 @@ class ChatAnthropic(BaseChatModel):
)
else:
msg = AIMessage(content=content)
# Collect token usage
msg.usage_metadata = {
"input_tokens": data.usage.input_tokens,
"output_tokens": data.usage.output_tokens,
"total_tokens": data.usage.input_tokens + data.usage.output_tokens,
}
msg.usage_metadata = _create_usage_metadata(data.usage)
return ChatResult(
generations=[ChatGeneration(message=msg)],
llm_output=llm_output,
@@ -1182,14 +1177,10 @@ def _make_message_chunk_from_anthropic_event(
message_chunk: Optional[AIMessageChunk] = None
# See https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/lib/streaming/_messages.py # noqa: E501
if event.type == "message_start" and stream_usage:
input_tokens = event.message.usage.input_tokens
usage_metadata = _create_usage_metadata(event.message.usage)
message_chunk = AIMessageChunk(
content="" if coerce_content_to_string else [],
usage_metadata=UsageMetadata(
input_tokens=input_tokens,
output_tokens=0,
total_tokens=input_tokens,
),
usage_metadata=usage_metadata,
)
elif (
event.type == "content_block_start"
@@ -1235,14 +1226,10 @@ def _make_message_chunk_from_anthropic_event(
tool_call_chunks=[tool_call_chunk], # type: ignore
)
elif event.type == "message_delta" and stream_usage:
output_tokens = event.usage.output_tokens
usage_metadata = _create_usage_metadata(event.usage)
message_chunk = AIMessageChunk(
content="",
usage_metadata=UsageMetadata(
input_tokens=0,
output_tokens=output_tokens,
total_tokens=output_tokens,
),
usage_metadata=usage_metadata,
response_metadata={
"stop_reason": event.delta.stop_reason,
"stop_sequence": event.delta.stop_sequence,
@@ -1257,3 +1244,21 @@ def _make_message_chunk_from_anthropic_event(
# Backward-compatibility alias: an empty subclass of ChatAnthropic that adds no
# behavior of its own. Marked deprecated since 0.1.0 and scheduled for removal
# in 0.3.0; the decorator names ChatAnthropic as the replacement.
@deprecated(since="0.1.0", removal="0.3.0", alternative="ChatAnthropic")
class ChatAnthropicMessages(ChatAnthropic):
pass
def _create_usage_metadata(anthropic_usage: BaseModel) -> UsageMetadata:
    """Convert an Anthropic usage object into a ``UsageMetadata``.

    Every field is read with ``getattr`` so that usage objects lacking some of
    the attributes are still accepted.

    Args:
        anthropic_usage: The usage object taken from an Anthropic response or
            streaming event.

    Returns:
        A ``UsageMetadata`` whose ``total_tokens`` is the sum of the input and
        output token counts. ``input_token_details`` carries ``cache_read`` /
        ``cache_creation`` only for the cache counters that are present and
        not ``None``.
    """
    # Maps UsageMetadata detail keys to the attribute names on the usage object.
    cache_attrs = {
        "cache_read": "cache_read_input_tokens",
        "cache_creation": "cache_creation_input_tokens",
    }
    input_token_details: Dict = {}
    for detail_key, attr_name in cache_attrs.items():
        value = getattr(anthropic_usage, attr_name, None)
        if value is not None:
            input_token_details[detail_key] = value

    # The getattr default only covers a *missing* attribute. An attribute that
    # is present but set to None would otherwise break the addition below, so
    # coerce None to 0.
    input_tokens = getattr(anthropic_usage, "input_tokens", 0) or 0
    output_tokens = getattr(anthropic_usage, "output_tokens", 0) or 0
    return UsageMetadata(
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        total_tokens=input_tokens + output_tokens,
        input_token_details=InputTokenDetails(**input_token_details),
    )