mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-13 16:36:06 +00:00
anthropic: refactor streaming to use events api; add streaming usage metadata (#22628)
- Refactor streaming to use raw events; - Add `stream_usage` class attribute and kwarg to stream methods that, if True, will include separate chunks in the stream containing usage metadata. There are two ways to implement streaming with anthropic's python sdk. They have slight differences in how they surface usage metadata. 1. [Use helper functions](https://github.com/anthropics/anthropic-sdk-python?tab=readme-ov-file#streaming-helpers). This is what we are doing now. ```python count = 1 with client.messages.stream(**params) as stream: for text in stream.text_stream: snapshot = stream.current_message_snapshot print(f"{count}: {snapshot.usage} -- {text}") count = count + 1 final_snapshot = stream.get_final_message() print(f"{count}: {final_snapshot.usage}") ``` ``` 1: Usage(input_tokens=8, output_tokens=1) -- Hello 2: Usage(input_tokens=8, output_tokens=1) -- ! 3: Usage(input_tokens=8, output_tokens=1) -- How 4: Usage(input_tokens=8, output_tokens=1) -- can 5: Usage(input_tokens=8, output_tokens=1) -- I 6: Usage(input_tokens=8, output_tokens=1) -- assist 7: Usage(input_tokens=8, output_tokens=1) -- you 8: Usage(input_tokens=8, output_tokens=1) -- today 9: Usage(input_tokens=8, output_tokens=1) -- ? 10: Usage(input_tokens=8, output_tokens=12) ``` To do this correctly, we need to emit a new chunk at the end of the stream containing the usage metadata. 2. [Handle raw events](https://github.com/anthropics/anthropic-sdk-python?tab=readme-ov-file#streaming-responses) ```python stream = client.messages.create(**params, stream=True) count = 1 for event in stream: print(f"{count}: {event}") count = count + 1 ``` ``` 1: RawMessageStartEvent(message=Message(id='msg_01Vdyov2kADZTXqSKkfNJXcS', content=[], model='claude-3-haiku-20240307', role='assistant', stop_reason=None, stop_sequence=None, type='message', usage=Usage(input_tokens=8, output_tokens=1)), type='message_start') 2: RawContentBlockStartEvent(content_block=TextBlock(text='', type='text'), index=0, type='content_block_start') 3: RawContentBlockDeltaEvent(delta=TextDelta(text='Hello', type='text_delta'), index=0, type='content_block_delta') 4: RawContentBlockDeltaEvent(delta=TextDelta(text='!', type='text_delta'), index=0, type='content_block_delta') 5: RawContentBlockDeltaEvent(delta=TextDelta(text=' How', type='text_delta'), index=0, type='content_block_delta') 6: RawContentBlockDeltaEvent(delta=TextDelta(text=' can', type='text_delta'), index=0, type='content_block_delta') 7: RawContentBlockDeltaEvent(delta=TextDelta(text=' I', type='text_delta'), index=0, type='content_block_delta') 8: RawContentBlockDeltaEvent(delta=TextDelta(text=' assist', type='text_delta'), index=0, type='content_block_delta') 9: RawContentBlockDeltaEvent(delta=TextDelta(text=' you', type='text_delta'), index=0, type='content_block_delta') 10: RawContentBlockDeltaEvent(delta=TextDelta(text=' today', type='text_delta'), index=0, type='content_block_delta') 11: RawContentBlockDeltaEvent(delta=TextDelta(text='?', type='text_delta'), index=0, type='content_block_delta') 12: RawContentBlockStopEvent(index=0, type='content_block_stop') 13: RawMessageDeltaEvent(delta=Delta(stop_reason='end_turn', stop_sequence=None), type='message_delta', usage=MessageDeltaUsage(output_tokens=12)) 14: RawMessageStopEvent(type='message_stop') ``` Here we implement the second option, in part because it should make things easier when implementing streaming tool calls in the near future. This would add two new chunks to the stream-- one at the beginning and one at the end-- with blank content and containing usage metadata. We add kwargs to the stream methods and a class attribute allowing for this behavior to be toggled. I enabled it by default. If we merge this we can add the same kwargs / attribute to OpenAI. Usage: ```python from langchain_anthropic import ChatAnthropic model = ChatAnthropic( model="claude-3-haiku-20240307", temperature=0 ) full = None for chunk in model.stream("hi"): full = chunk if full is None else full + chunk print(chunk) print(f"\nFull: {full}") ``` ``` content='' id='run-8a20843f-25c7-4025-ad72-9add395899e3' usage_metadata={'input_tokens': 8, 'output_tokens': 0, 'total_tokens': 8} content='Hello' id='run-8a20843f-25c7-4025-ad72-9add395899e3' content='!' id='run-8a20843f-25c7-4025-ad72-9add395899e3' content=' How' id='run-8a20843f-25c7-4025-ad72-9add395899e3' content=' can' id='run-8a20843f-25c7-4025-ad72-9add395899e3' content=' I' id='run-8a20843f-25c7-4025-ad72-9add395899e3' content=' assist' id='run-8a20843f-25c7-4025-ad72-9add395899e3' content=' you' id='run-8a20843f-25c7-4025-ad72-9add395899e3' content=' today' id='run-8a20843f-25c7-4025-ad72-9add395899e3' content='?' id='run-8a20843f-25c7-4025-ad72-9add395899e3' content='' id='run-8a20843f-25c7-4025-ad72-9add395899e3' usage_metadata={'input_tokens': 0, 'output_tokens': 12, 'total_tokens': 12} Full: content='Hello! How can I assist you today?' id='run-8a20843f-25c7-4025-ad72-9add395899e3' usage_metadata={'input_tokens': 8, 'output_tokens': 12, 'total_tokens': 20} ```
This commit is contained in:
parent
235d91940d
commit
f32d57f6f0
@ -43,6 +43,7 @@ from langchain_core.messages import (
|
||||
ToolCall,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.messages.ai import UsageMetadata
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
||||
from langchain_core.runnables import (
|
||||
@ -446,6 +447,24 @@ class ChatAnthropic(BaseChatModel):
|
||||
|
||||
{'input_tokens': 25, 'output_tokens': 11, 'total_tokens': 36}
|
||||
|
||||
Message chunks containing token usage will be included during streaming by
|
||||
default:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
stream = llm.stream(messages)
|
||||
full = next(stream)
|
||||
for chunk in stream:
|
||||
full += chunk
|
||||
full.usage_metadata
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{'input_tokens': 25, 'output_tokens': 11, 'total_tokens': 36}
|
||||
|
||||
These can be disabled by setting ``stream_usage=False`` in the stream method,
|
||||
or by setting ``stream_usage=False`` when initializing ChatAnthropic.
|
||||
|
||||
Response metadata
|
||||
.. code-block:: python
|
||||
|
||||
@ -513,6 +532,11 @@ class ChatAnthropic(BaseChatModel):
|
||||
streaming: bool = False
|
||||
"""Whether to use streaming or not."""
|
||||
|
||||
stream_usage: bool = True
|
||||
"""Whether to include usage metadata in streaming output. If True, additional
|
||||
message chunks will be generated during the stream including usage metadata.
|
||||
"""
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of chat model."""
|
||||
@ -636,8 +660,12 @@ class ChatAnthropic(BaseChatModel):
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
*,
|
||||
stream_usage: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
if stream_usage is None:
|
||||
stream_usage = self.stream_usage
|
||||
params = self._format_params(messages=messages, stop=stop, **kwargs)
|
||||
if _tools_in_params(params):
|
||||
result = self._generate(
|
||||
@ -657,16 +685,21 @@ class ChatAnthropic(BaseChatModel):
|
||||
message_chunk = AIMessageChunk(
|
||||
content=message.content,
|
||||
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
|
||||
usage_metadata=message.usage_metadata,
|
||||
)
|
||||
yield ChatGenerationChunk(message=message_chunk)
|
||||
else:
|
||||
yield cast(ChatGenerationChunk, result.generations[0])
|
||||
return
|
||||
with self._client.messages.stream(**params) as stream:
|
||||
for text in stream.text_stream:
|
||||
chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(text, chunk=chunk)
|
||||
stream = self._client.messages.create(**params, stream=True)
|
||||
for event in stream:
|
||||
msg = _make_message_chunk_from_anthropic_event(
|
||||
event, stream_usage=stream_usage
|
||||
)
|
||||
if msg is not None:
|
||||
chunk = ChatGenerationChunk(message=msg)
|
||||
if run_manager and isinstance(msg.content, str):
|
||||
run_manager.on_llm_new_token(msg.content, chunk=chunk)
|
||||
yield chunk
|
||||
|
||||
async def _astream(
|
||||
@ -674,8 +707,12 @@ class ChatAnthropic(BaseChatModel):
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
*,
|
||||
stream_usage: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
if stream_usage is None:
|
||||
stream_usage = self.stream_usage
|
||||
params = self._format_params(messages=messages, stop=stop, **kwargs)
|
||||
if _tools_in_params(params):
|
||||
warnings.warn("stream: Tool use is not yet supported in streaming mode.")
|
||||
@ -696,16 +733,21 @@ class ChatAnthropic(BaseChatModel):
|
||||
message_chunk = AIMessageChunk(
|
||||
content=message.content,
|
||||
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
|
||||
usage_metadata=message.usage_metadata,
|
||||
)
|
||||
yield ChatGenerationChunk(message=message_chunk)
|
||||
else:
|
||||
yield cast(ChatGenerationChunk, result.generations[0])
|
||||
return
|
||||
async with self._async_client.messages.stream(**params) as stream:
|
||||
async for text in stream.text_stream:
|
||||
chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(text, chunk=chunk)
|
||||
stream = await self._async_client.messages.create(**params, stream=True)
|
||||
async for event in stream:
|
||||
msg = _make_message_chunk_from_anthropic_event(
|
||||
event, stream_usage=stream_usage
|
||||
)
|
||||
if msg is not None:
|
||||
chunk = ChatGenerationChunk(message=msg)
|
||||
if run_manager and isinstance(msg.content, str):
|
||||
await run_manager.on_llm_new_token(msg.content, chunk=chunk)
|
||||
yield chunk
|
||||
|
||||
def _format_output(self, data: Any, **kwargs: Any) -> ChatResult:
|
||||
@ -1068,6 +1110,47 @@ def _lc_tool_calls_to_anthropic_tool_use_blocks(
|
||||
return blocks
|
||||
|
||||
|
||||
def _make_message_chunk_from_anthropic_event(
|
||||
event: anthropic.types.RawMessageStreamEvent,
|
||||
*,
|
||||
stream_usage: bool = True,
|
||||
) -> Optional[AIMessageChunk]:
|
||||
"""Convert Anthropic event to AIMessageChunk.
|
||||
|
||||
Note that not all events will result in a message chunk. In these cases
|
||||
we return None.
|
||||
"""
|
||||
message_chunk: Optional[AIMessageChunk] = None
|
||||
if event.type == "message_start" and stream_usage:
|
||||
input_tokens = event.message.usage.input_tokens
|
||||
message_chunk = AIMessageChunk(
|
||||
content="",
|
||||
usage_metadata=UsageMetadata(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=0,
|
||||
total_tokens=input_tokens,
|
||||
),
|
||||
)
|
||||
# See https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/lib/streaming/_messages.py # noqa: E501
|
||||
elif event.type == "content_block_delta" and event.delta.type == "text_delta":
|
||||
text = event.delta.text
|
||||
message_chunk = AIMessageChunk(content=text)
|
||||
elif event.type == "message_delta" and stream_usage:
|
||||
output_tokens = event.usage.output_tokens
|
||||
message_chunk = AIMessageChunk(
|
||||
content="",
|
||||
usage_metadata=UsageMetadata(
|
||||
input_tokens=0,
|
||||
output_tokens=output_tokens,
|
||||
total_tokens=output_tokens,
|
||||
),
|
||||
)
|
||||
else:
|
||||
pass
|
||||
|
||||
return message_chunk
|
||||
|
||||
|
||||
@deprecated(since="0.1.0", removal="0.3.0", alternative="ChatAnthropic")
|
||||
class ChatAnthropicMessages(ChatAnthropic):
|
||||
pass
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""Test ChatAnthropic chat model."""
|
||||
|
||||
import json
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
import pytest
|
||||
from langchain_core.callbacks import CallbackManager
|
||||
@ -9,6 +9,7 @@ from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
@ -28,16 +29,101 @@ def test_stream() -> None:
|
||||
"""Test streaming tokens from Anthropic."""
|
||||
llm = ChatAnthropicMessages(model_name=MODEL_NAME) # type: ignore[call-arg, call-arg]
|
||||
|
||||
full: Optional[BaseMessageChunk] = None
|
||||
chunks_with_input_token_counts = 0
|
||||
chunks_with_output_token_counts = 0
|
||||
for token in llm.stream("I'm Pickle Rick"):
|
||||
assert isinstance(token.content, str)
|
||||
full = token if full is None else full + token
|
||||
assert isinstance(token, AIMessageChunk)
|
||||
if token.usage_metadata is not None:
|
||||
if token.usage_metadata.get("input_tokens"):
|
||||
chunks_with_input_token_counts += 1
|
||||
elif token.usage_metadata.get("output_tokens"):
|
||||
chunks_with_output_token_counts += 1
|
||||
if chunks_with_input_token_counts != 1 or chunks_with_output_token_counts != 1:
|
||||
raise AssertionError(
|
||||
"Expected exactly one chunk with input or output token counts. "
|
||||
"AIMessageChunk aggregation adds counts. Check that "
|
||||
"this is behaving properly."
|
||||
)
|
||||
# check token usage is populated
|
||||
assert isinstance(full, AIMessageChunk)
|
||||
assert full.usage_metadata is not None
|
||||
assert full.usage_metadata["input_tokens"] > 0
|
||||
assert full.usage_metadata["output_tokens"] > 0
|
||||
assert full.usage_metadata["total_tokens"] > 0
|
||||
assert (
|
||||
full.usage_metadata["input_tokens"] + full.usage_metadata["output_tokens"]
|
||||
== full.usage_metadata["total_tokens"]
|
||||
)
|
||||
|
||||
|
||||
async def test_astream() -> None:
|
||||
"""Test streaming tokens from Anthropic."""
|
||||
llm = ChatAnthropicMessages(model_name=MODEL_NAME) # type: ignore[call-arg, call-arg]
|
||||
|
||||
full: Optional[BaseMessageChunk] = None
|
||||
chunks_with_input_token_counts = 0
|
||||
chunks_with_output_token_counts = 0
|
||||
async for token in llm.astream("I'm Pickle Rick"):
|
||||
assert isinstance(token.content, str)
|
||||
full = token if full is None else full + token
|
||||
assert isinstance(token, AIMessageChunk)
|
||||
if token.usage_metadata is not None:
|
||||
if token.usage_metadata.get("input_tokens"):
|
||||
chunks_with_input_token_counts += 1
|
||||
elif token.usage_metadata.get("output_tokens"):
|
||||
chunks_with_output_token_counts += 1
|
||||
if chunks_with_input_token_counts != 1 or chunks_with_output_token_counts != 1:
|
||||
raise AssertionError(
|
||||
"Expected exactly one chunk with input or output token counts. "
|
||||
"AIMessageChunk aggregation adds counts. Check that "
|
||||
"this is behaving properly."
|
||||
)
|
||||
# check token usage is populated
|
||||
assert isinstance(full, AIMessageChunk)
|
||||
assert full.usage_metadata is not None
|
||||
assert full.usage_metadata["input_tokens"] > 0
|
||||
assert full.usage_metadata["output_tokens"] > 0
|
||||
assert full.usage_metadata["total_tokens"] > 0
|
||||
assert (
|
||||
full.usage_metadata["input_tokens"] + full.usage_metadata["output_tokens"]
|
||||
== full.usage_metadata["total_tokens"]
|
||||
)
|
||||
|
||||
# test usage metadata can be excluded
|
||||
model = ChatAnthropic(model_name=MODEL_NAME, stream_usage=False) # type: ignore[call-arg]
|
||||
async for token in model.astream("hi"):
|
||||
assert isinstance(token, AIMessageChunk)
|
||||
assert token.usage_metadata is None
|
||||
# check we override with kwarg
|
||||
model = ChatAnthropic(model_name=MODEL_NAME) # type: ignore[call-arg]
|
||||
assert model.stream_usage
|
||||
async for token in model.astream("hi", stream_usage=False):
|
||||
assert isinstance(token, AIMessageChunk)
|
||||
assert token.usage_metadata is None
|
||||
|
||||
# Check expected raw API output
|
||||
async_client = model._async_client
|
||||
params: dict = {
|
||||
"model": "claude-3-haiku-20240307",
|
||||
"max_tokens": 1024,
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"temperature": 0.0,
|
||||
}
|
||||
stream = await async_client.messages.create(**params, stream=True)
|
||||
async for event in stream:
|
||||
if event.type == "message_start":
|
||||
assert event.message.usage.input_tokens > 1
|
||||
# Note: this single output token included in message start event
|
||||
# does not appear to contribute to overall output token counts. It
|
||||
# is excluded from the total token count.
|
||||
assert event.message.usage.output_tokens == 1
|
||||
elif event.type == "message_delta":
|
||||
assert event.usage.output_tokens > 1
|
||||
else:
|
||||
pass
|
||||
|
||||
|
||||
async def test_abatch() -> None:
|
||||
|
Loading…
Reference in New Issue
Block a user