diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py
index 2bf4c1af15c..d7e63c8c04d 100644
--- a/libs/partners/anthropic/langchain_anthropic/chat_models.py
+++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py
@@ -43,6 +43,7 @@ from langchain_core.messages import (
     ToolCall,
     ToolMessage,
 )
+from langchain_core.messages.ai import UsageMetadata
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
 from langchain_core.runnables import (
@@ -446,6 +447,24 @@ class ChatAnthropic(BaseChatModel):
 
             {'input_tokens': 25, 'output_tokens': 11, 'total_tokens': 36}
 
+        Message chunks containing token usage will be included during streaming by
+        default:
+
+        .. code-block:: python
+
+            stream = llm.stream(messages)
+            full = next(stream)
+            for chunk in stream:
+                full += chunk
+            full.usage_metadata
+
+        .. code-block:: python
+
+            {'input_tokens': 25, 'output_tokens': 11, 'total_tokens': 36}
+
+        Usage metadata can be excluded by setting ``stream_usage=False`` in the stream
+        method, or by setting ``stream_usage=False`` when initializing ChatAnthropic.
+
     Response metadata
         .. code-block:: python
@@ -513,6 +532,11 @@ class ChatAnthropic(BaseChatModel):
     streaming: bool = False
     """Whether to use streaming or not."""
 
+    stream_usage: bool = True
+    """Whether to include usage metadata in streaming output. If True, additional
+    message chunks containing usage metadata will be generated during the stream.
+    """
+
     @property
     def _llm_type(self) -> str:
         """Return type of chat model."""
@@ -636,8 +660,12 @@ class ChatAnthropic(BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        *,
+        stream_usage: Optional[bool] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
+        if stream_usage is None:
+            stream_usage = self.stream_usage
         params = self._format_params(messages=messages, stop=stop, **kwargs)
         if _tools_in_params(params):
             result = self._generate(
@@ -657,16 +685,21 @@ class ChatAnthropic(BaseChatModel):
                 message_chunk = AIMessageChunk(
                     content=message.content,
                     tool_call_chunks=tool_call_chunks,  # type: ignore[arg-type]
+                    usage_metadata=message.usage_metadata,
                 )
                 yield ChatGenerationChunk(message=message_chunk)
             else:
                 yield cast(ChatGenerationChunk, result.generations[0])
             return
-        with self._client.messages.stream(**params) as stream:
-            for text in stream.text_stream:
-                chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
-                if run_manager:
-                    run_manager.on_llm_new_token(text, chunk=chunk)
+        stream = self._client.messages.create(**params, stream=True)
+        for event in stream:
+            msg = _make_message_chunk_from_anthropic_event(
+                event, stream_usage=stream_usage
+            )
+            if msg is not None:
+                chunk = ChatGenerationChunk(message=msg)
+                if run_manager and isinstance(msg.content, str):
+                    run_manager.on_llm_new_token(msg.content, chunk=chunk)
                 yield chunk
 
     async def _astream(
         self,
@@ -674,8 +707,12 @@ class ChatAnthropic(BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        *,
+        stream_usage: Optional[bool] = None,
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
+        if stream_usage is None:
+            stream_usage = self.stream_usage
         params = self._format_params(messages=messages, stop=stop, **kwargs)
         if _tools_in_params(params):
             warnings.warn("stream: Tool use is not yet supported in streaming mode.")
@@ -696,16 +733,21 @@ class ChatAnthropic(BaseChatModel):
                 message_chunk = AIMessageChunk(
                     content=message.content,
                     tool_call_chunks=tool_call_chunks,  # type: ignore[arg-type]
+                    usage_metadata=message.usage_metadata,
                 )
                 yield ChatGenerationChunk(message=message_chunk)
             else:
                 yield cast(ChatGenerationChunk, result.generations[0])
             return
-        async with self._async_client.messages.stream(**params) as stream:
-            async for text in stream.text_stream:
-                chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
-                if run_manager:
-                    await run_manager.on_llm_new_token(text, chunk=chunk)
+        stream = await self._async_client.messages.create(**params, stream=True)
+        async for event in stream:
+            msg = _make_message_chunk_from_anthropic_event(
+                event, stream_usage=stream_usage
+            )
+            if msg is not None:
+                chunk = ChatGenerationChunk(message=msg)
+                if run_manager and isinstance(msg.content, str):
+                    await run_manager.on_llm_new_token(msg.content, chunk=chunk)
                 yield chunk
 
     def _format_output(self, data: Any, **kwargs: Any) -> ChatResult:
@@ -1068,6 +1110,47 @@ def _lc_tool_calls_to_anthropic_tool_use_blocks(
     return blocks
 
 
+def _make_message_chunk_from_anthropic_event(
+    event: anthropic.types.RawMessageStreamEvent,
+    *,
+    stream_usage: bool = True,
+) -> Optional[AIMessageChunk]:
+    """Convert Anthropic event to AIMessageChunk.
+
+    Note that not all events will result in a message chunk. In these cases
+    we return None.
+    """
+    message_chunk: Optional[AIMessageChunk] = None
+    if event.type == "message_start" and stream_usage:
+        input_tokens = event.message.usage.input_tokens
+        message_chunk = AIMessageChunk(
+            content="",
+            usage_metadata=UsageMetadata(
+                input_tokens=input_tokens,
+                output_tokens=0,
+                total_tokens=input_tokens,
+            ),
+        )
+    # See https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/lib/streaming/_messages.py  # noqa: E501
+    elif event.type == "content_block_delta" and event.delta.type == "text_delta":
+        text = event.delta.text
+        message_chunk = AIMessageChunk(content=text)
+    elif event.type == "message_delta" and stream_usage:
+        output_tokens = event.usage.output_tokens
+        message_chunk = AIMessageChunk(
+            content="",
+            usage_metadata=UsageMetadata(
+                input_tokens=0,
+                output_tokens=output_tokens,
+                total_tokens=output_tokens,
+            ),
+        )
+    else:
+        pass
+
+    return message_chunk
+
+
 @deprecated(since="0.1.0", removal="0.3.0", alternative="ChatAnthropic")
 class ChatAnthropicMessages(ChatAnthropic):
     pass
diff --git a/libs/partners/anthropic/tests/integration_tests/test_chat_models.py b/libs/partners/anthropic/tests/integration_tests/test_chat_models.py
index cee2cf70cf8..3ccda5ab3e8 100644
--- a/libs/partners/anthropic/tests/integration_tests/test_chat_models.py
+++ b/libs/partners/anthropic/tests/integration_tests/test_chat_models.py
@@ -1,7 +1,7 @@
 """Test ChatAnthropic chat model."""
 
 import json
-from typing import List
+from typing import List, Optional
 
 import pytest
 from langchain_core.callbacks import CallbackManager
@@ -9,6 +9,7 @@ from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
     BaseMessage,
+    BaseMessageChunk,
     HumanMessage,
     SystemMessage,
     ToolMessage,
@@ -28,16 +29,101 @@ def test_stream() -> None:
     """Test streaming tokens from Anthropic."""
     llm = ChatAnthropicMessages(model_name=MODEL_NAME)  # type: ignore[call-arg, call-arg]
 
+    full: Optional[BaseMessageChunk] = None
+    chunks_with_input_token_counts = 0
+    chunks_with_output_token_counts = 0
     for token in llm.stream("I'm Pickle Rick"):
         assert isinstance(token.content, str)
+        full = token if full is None else full + token
+        assert isinstance(token, AIMessageChunk)
+        if token.usage_metadata is not None:
+            if token.usage_metadata.get("input_tokens"):
+                chunks_with_input_token_counts += 1
+            elif token.usage_metadata.get("output_tokens"):
+                chunks_with_output_token_counts += 1
+    if chunks_with_input_token_counts != 1 or chunks_with_output_token_counts != 1:
+        raise AssertionError(
+            "Expected exactly one chunk with input token counts and one chunk "
+            "with output token counts. AIMessageChunk aggregation sums these "
+            "counts, so duplicate chunks would inflate the totals."
+        )
+    # check token usage is populated
+    assert isinstance(full, AIMessageChunk)
+    assert full.usage_metadata is not None
+    assert full.usage_metadata["input_tokens"] > 0
+    assert full.usage_metadata["output_tokens"] > 0
+    assert full.usage_metadata["total_tokens"] > 0
+    assert (
+        full.usage_metadata["input_tokens"] + full.usage_metadata["output_tokens"]
+        == full.usage_metadata["total_tokens"]
+    )
 
 
 async def test_astream() -> None:
     """Test streaming tokens from Anthropic."""
     llm = ChatAnthropicMessages(model_name=MODEL_NAME)  # type: ignore[call-arg, call-arg]
 
+    full: Optional[BaseMessageChunk] = None
+    chunks_with_input_token_counts = 0
+    chunks_with_output_token_counts = 0
     async for token in llm.astream("I'm Pickle Rick"):
         assert isinstance(token.content, str)
+        full = token if full is None else full + token
+        assert isinstance(token, AIMessageChunk)
+        if token.usage_metadata is not None:
+            if token.usage_metadata.get("input_tokens"):
+                chunks_with_input_token_counts += 1
+            elif token.usage_metadata.get("output_tokens"):
+                chunks_with_output_token_counts += 1
+    if chunks_with_input_token_counts != 1 or chunks_with_output_token_counts != 1:
+        raise AssertionError(
+            "Expected exactly one chunk with input token counts and one chunk "
+            "with output token counts. AIMessageChunk aggregation sums these "
+            "counts, so duplicate chunks would inflate the totals."
+        )
+    # check token usage is populated
+    assert isinstance(full, AIMessageChunk)
+    assert full.usage_metadata is not None
+    assert full.usage_metadata["input_tokens"] > 0
+    assert full.usage_metadata["output_tokens"] > 0
+    assert full.usage_metadata["total_tokens"] > 0
+    assert (
+        full.usage_metadata["input_tokens"] + full.usage_metadata["output_tokens"]
+        == full.usage_metadata["total_tokens"]
+    )
+
+    # test usage metadata can be excluded
+    model = ChatAnthropic(model_name=MODEL_NAME, stream_usage=False)  # type: ignore[call-arg]
+    async for token in model.astream("hi"):
+        assert isinstance(token, AIMessageChunk)
+        assert token.usage_metadata is None
+    # check that the per-call kwarg overrides the class attribute
+    model = ChatAnthropic(model_name=MODEL_NAME)  # type: ignore[call-arg]
+    assert model.stream_usage
+    async for token in model.astream("hi", stream_usage=False):
+        assert isinstance(token, AIMessageChunk)
+        assert token.usage_metadata is None
+
+    # Check expected raw API output
+    async_client = model._async_client
+    params: dict = {
+        "model": "claude-3-haiku-20240307",
+        "max_tokens": 1024,
+        "messages": [{"role": "user", "content": "hi"}],
+        "temperature": 0.0,
+    }
+    stream = await async_client.messages.create(**params, stream=True)
+    async for event in stream:
+        if event.type == "message_start":
+            assert event.message.usage.input_tokens > 1
+            # Note: the single output token reported in the message_start event
+            # does not appear to contribute to the overall output token count,
+            # so it is excluded when building usage metadata.
+            assert event.message.usage.output_tokens == 1
+        elif event.type == "message_delta":
+            assert event.usage.output_tokens > 1
+        else:
+            pass
 
 
 async def test_abatch() -> None:
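For completeness, a minimal end-to-end sketch of the behavior these tests exercise, assuming the package versions in this PR (the model name is a placeholder; summing of ``usage_metadata`` across chunks is handled by ``AIMessageChunk`` addition in ``langchain_core``):

.. code-block:: python

    from langchain_anthropic import ChatAnthropic

    llm = ChatAnthropic(model="claude-3-haiku-20240307")  # stream_usage=True by default

    full = None
    for chunk in llm.stream("I'm Pickle Rick"):
        full = chunk if full is None else full + chunk
    print(full.usage_metadata)
    # e.g. {'input_tokens': 13, 'output_tokens': 42, 'total_tokens': 55}

    # Usage chunks can be suppressed per call or at construction time:
    for chunk in llm.stream("I'm Pickle Rick", stream_usage=False):
        assert chunk.usage_metadata is None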