anthropic: stream token usage (#20180)
open to other ideas

[screenshot: Screenshot 2024-04-08 at 5 34 08 PM]

---------

Co-authored-by: Chester Curme <chester.curme@gmail.com>
parent e0e40f3f63
commit 0d495f3f63
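With this change, the chunks yielded by ChatAnthropic while streaming carry usage_metadata, so callers can read token usage without a separate non-streaming call. A minimal sketch of the resulting behavior, assuming this patch is installed and ANTHROPIC_API_KEY is set (the model name and token counts below are illustrative):

# Minimal usage sketch (assumes this patch; model name is illustrative).
from langchain_anthropic import ChatAnthropic

llm = ChatAnthropic(model="claude-3-sonnet-20240229")

aggregate = None
for chunk in llm.stream("Hello"):
    # Each streamed AIMessageChunk now carries usage_metadata.
    print(chunk.content, chunk.usage_metadata)
    aggregate = chunk if aggregate is None else aggregate + chunk

# Summing chunks yields totals for the whole stream, e.g.
# {'input_tokens': 8, 'output_tokens': 12, 'total_tokens': 20}.
print(aggregate.usage_metadata)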
@@ -43,6 +43,7 @@ from langchain_core.messages import (
     ToolCall,
     ToolMessage,
 )
+from langchain_core.messages.ai import UsageMetadata
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
 from langchain_core.runnables import (
@@ -653,14 +654,20 @@ class ChatAnthropic(BaseChatModel):
                 message_chunk = AIMessageChunk(
                     content=message.content,
                     tool_call_chunks=tool_call_chunks,  # type: ignore[arg-type]
+                    usage_metadata=message.usage_metadata,
                 )
                 yield ChatGenerationChunk(message=message_chunk)
             else:
                 yield cast(ChatGenerationChunk, result.generations[0])
             return
+        full_generation_info: dict = {}
         with self._client.messages.stream(**params) as stream:
             for text in stream.text_stream:
-                chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
+                chunk, full_generation_info = _make_chat_generation_chunk(
+                    text,
+                    stream.current_message_snapshot.model_dump(),
+                    full_generation_info,
+                )
                 if run_manager:
                     run_manager.on_llm_new_token(text, chunk=chunk)
                 yield chunk
@@ -692,14 +699,20 @@ class ChatAnthropic(BaseChatModel):
                 message_chunk = AIMessageChunk(
                     content=message.content,
                     tool_call_chunks=tool_call_chunks,  # type: ignore[arg-type]
+                    usage_metadata=message.usage_metadata,
                 )
                 yield ChatGenerationChunk(message=message_chunk)
             else:
                 yield cast(ChatGenerationChunk, result.generations[0])
             return
+        full_generation_info: dict = {}
         async with self._async_client.messages.stream(**params) as stream:
             async for text in stream.text_stream:
-                chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
+                chunk, full_generation_info = _make_chat_generation_chunk(
+                    text,
+                    stream.current_message_snapshot.model_dump(),
+                    full_generation_info,
+                )
                 if run_manager:
                     await run_manager.on_llm_new_token(text, chunk=chunk)
                 yield chunk
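Both streaming paths now build each chunk from the SDK's current_message_snapshot rather than from the raw text alone. A sketch of what that snapshot exposes, using the anthropic SDK directly (requires ANTHROPIC_API_KEY; the model name is illustrative and the printed numbers will vary), mirroring the check added to test_astream further below:

# Sketch: inspect the per-event message snapshot that the helper consumes.
import anthropic

client = anthropic.Anthropic()
params: dict = {
    "model": "claude-3-sonnet-20240229",
    "max_tokens": 64,
    "messages": [{"role": "user", "content": "I'm Pickle Rick"}],
}
with client.messages.stream(**params) as stream:
    for _ in stream.text_stream:
        snapshot = stream.current_message_snapshot.model_dump()
        # The same input token count is repeated in every snapshot.
        print(snapshot.get("usage"))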
@@ -1064,6 +1077,59 @@ def _lc_tool_calls_to_anthropic_tool_use_blocks(
     return blocks
+
+
+def _make_chat_generation_chunk(
+    text: str, message_dump: dict, full_generation_info: dict
+) -> Tuple[ChatGenerationChunk, dict]:
+    """Collect metadata and make ChatGenerationChunk.
+
+    Args:
+        text: text of the message chunk
+        message_dump: dict with metadata of the message chunk
+        full_generation_info: dict collecting metadata for full stream
+
+    Returns:
+        Tuple with ChatGenerationChunk and updated full_generation_info
+    """
+    generation_info = {}
+    usage_metadata: Optional[UsageMetadata] = None
+    for k, v in message_dump.items():
+        if k in ("content", "role", "type") or (
+            k in full_generation_info and k not in ("usage", "stop_reason")
+        ):
+            continue
+        elif k == "usage":
+            input_tokens = v.get("input_tokens", 0)
+            output_tokens = v.get("output_tokens", 0)
+            if "usage" not in full_generation_info:
+                full_generation_info[k] = v
+                usage_metadata = UsageMetadata(
+                    input_tokens=input_tokens,
+                    output_tokens=output_tokens,
+                    total_tokens=input_tokens + output_tokens,
+                )
+            else:
+                seen_input_tokens = full_generation_info[k].get("input_tokens", 0)
+                # Anthropic returns the same input token count for each message in a
+                # stream. To avoid double counting, we only count the input tokens
+                # once. After that, we set the input tokens to zero.
+                new_input_tokens = 0 if seen_input_tokens else input_tokens
+                usage_metadata = UsageMetadata(
+                    input_tokens=new_input_tokens,
+                    output_tokens=output_tokens,
+                    total_tokens=new_input_tokens + output_tokens,
+                )
+        else:
+            full_generation_info[k] = v
+            generation_info[k] = v
+    return (
+        ChatGenerationChunk(
+            message=AIMessageChunk(content=text, usage_metadata=usage_metadata),
+            generation_info=generation_info,
+        ),
+        full_generation_info,
+    )


 @deprecated(since="0.1.0", removal="0.3.0", alternative="ChatAnthropic")
 class ChatAnthropicMessages(ChatAnthropic):
     pass
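The helper counts the prompt's input tokens only once per stream, because the SDK repeats the same figure in every message snapshot. A dependency-free sketch of that accounting (the snapshot usage dicts and token numbers below are made up):

# Simplified stand-in for the usage accounting in _make_chat_generation_chunk.
def usage_for_chunk(usage: dict, seen: dict) -> dict:
    """Return per-chunk usage, counting input tokens only for the first chunk."""
    input_tokens = usage.get("input_tokens", 0)
    output_tokens = usage.get("output_tokens", 0)
    if "usage" not in seen:
        seen["usage"] = usage
    else:
        # The SDK repeats the same input token count in every snapshot.
        input_tokens = 0
    return {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "total_tokens": input_tokens + output_tokens,
    }

seen: dict = {}
print(usage_for_chunk({"input_tokens": 10, "output_tokens": 1}, seen))
# -> {'input_tokens': 10, 'output_tokens': 1, 'total_tokens': 11}
print(usage_for_chunk({"input_tokens": 10, "output_tokens": 3}, seen))
# -> {'input_tokens': 0, 'output_tokens': 3, 'total_tokens': 3}

The remaining hunks extend the integration tests to cover this behavior.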
@@ -1,7 +1,7 @@
 """Test ChatAnthropic chat model."""

 import json
-from typing import List
+from typing import List, Optional

 import pytest
 from langchain_core.callbacks import CallbackManager
@@ -9,6 +9,7 @@ from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
     BaseMessage,
+    BaseMessageChunk,
     HumanMessage,
     SystemMessage,
     ToolMessage,
@@ -28,16 +29,80 @@ def test_stream() -> None:
     """Test streaming tokens from Anthropic."""
     llm = ChatAnthropicMessages(model_name=MODEL_NAME)  # type: ignore[call-arg, call-arg]
+
+    full: Optional[BaseMessageChunk] = None
+    chunks_with_input_token_counts = 0
     for token in llm.stream("I'm Pickle Rick"):
         assert isinstance(token.content, str)
+        full = token if full is None else full + token
+        assert isinstance(token, AIMessageChunk)
+        if token.usage_metadata is not None and token.usage_metadata.get(
+            "input_tokens"
+        ):
+            chunks_with_input_token_counts += 1
+    if chunks_with_input_token_counts != 1:
+        raise AssertionError(
+            "Expected exactly one chunk with input token counts. "
+            "AIMessageChunk aggregation adds counts. Check that "
+            "this is behaving properly."
+        )
+    # check token usage is populated
+    assert isinstance(full, AIMessageChunk)
+    assert full.usage_metadata is not None
+    assert full.usage_metadata["input_tokens"] > 0
+    assert full.usage_metadata["output_tokens"] > 0
+    assert full.usage_metadata["total_tokens"] > 0
+    assert (
+        full.usage_metadata["input_tokens"] + full.usage_metadata["output_tokens"]
+        == full.usage_metadata["total_tokens"]
+    )


 async def test_astream() -> None:
     """Test streaming tokens from Anthropic."""
     llm = ChatAnthropicMessages(model_name=MODEL_NAME)  # type: ignore[call-arg, call-arg]
+
+    full: Optional[BaseMessageChunk] = None
+    chunks_with_input_token_counts = 0
     async for token in llm.astream("I'm Pickle Rick"):
         assert isinstance(token.content, str)
+        full = token if full is None else full + token
+        assert isinstance(token, AIMessageChunk)
+        if token.usage_metadata is not None and token.usage_metadata.get(
+            "input_tokens"
+        ):
+            chunks_with_input_token_counts += 1
+    if chunks_with_input_token_counts != 1:
+        raise AssertionError(
+            "Expected exactly one chunk with input token counts. "
+            "AIMessageChunk aggregation adds counts. Check that "
+            "this is behaving properly."
+        )
+    # check token usage is populated
+    assert isinstance(full, AIMessageChunk)
+    assert full.usage_metadata is not None
+    assert full.usage_metadata["input_tokens"] > 0
+    assert full.usage_metadata["output_tokens"] > 0
+    assert full.usage_metadata["total_tokens"] > 0
+    assert (
+        full.usage_metadata["input_tokens"] + full.usage_metadata["output_tokens"]
+        == full.usage_metadata["total_tokens"]
+    )
+
+    # Check assumption that each chunk has identical input token counts.
+    # This assumption is baked into _make_chat_generation_chunk.
+    params: dict = {
+        "model": MODEL_NAME,
+        "max_tokens": 1024,
+        "messages": [{"role": "user", "content": "I'm Pickle Rick"}],
+    }
+    all_input_tokens = set()
+    async with llm._async_client.messages.stream(**params) as stream:
+        async for _ in stream.text_stream:
+            message_dump = stream.current_message_snapshot.model_dump()
+            if input_tokens := message_dump.get("usage", {}).get("input_tokens"):
+                assert input_tokens > 0
+                all_input_tokens.add(input_tokens)
+    assert len(all_input_tokens) == 1


 async def test_abatch() -> None:
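The tests insist on exactly one chunk reporting input tokens because adding AIMessageChunks sums their usage_metadata; repeating the prompt count in every chunk would double count it. A small offline sketch of that aggregation, assuming a langchain-core version with usage_metadata support as this patch requires (token numbers are made up):

# Sketch: AIMessageChunk addition sums usage_metadata across chunks.
from langchain_core.messages import AIMessageChunk
from langchain_core.messages.ai import UsageMetadata

first = AIMessageChunk(
    content="Hel",
    usage_metadata=UsageMetadata(input_tokens=10, output_tokens=1, total_tokens=11),
)
second = AIMessageChunk(
    content="lo",
    usage_metadata=UsageMetadata(input_tokens=0, output_tokens=2, total_tokens=2),
)

merged = first + second
# Expected: input_tokens=10, output_tokens=3, total_tokens=13.
print(merged.usage_metadata)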
@@ -268,6 +333,17 @@ def test_tool_use() -> None:
     assert isinstance(tool_call_chunk["args"], str)
     assert "location" in json.loads(tool_call_chunk["args"])

+    # Check usage metadata
+    assert gathered.usage_metadata is not None
+    assert gathered.usage_metadata["input_tokens"] > 0
+    assert gathered.usage_metadata["output_tokens"] > 0
+    assert gathered.usage_metadata["total_tokens"] > 0
+    assert (
+        gathered.usage_metadata["input_tokens"]
+        + gathered.usage_metadata["output_tokens"]
+        == gathered.usage_metadata["total_tokens"]
+    )
+

 def test_anthropic_with_empty_text_block() -> None:
     """Anthropic SDK can return an empty text block."""