From 181a61982fa5d96e2b5f0001acddde73b51a6605 Mon Sep 17 00:00:00 2001
From: Jacob Lee
Date: Fri, 14 Jun 2024 09:14:43 -0700
Subject: [PATCH] anthropic[minor]: Adds streaming tool call support for Anthropic (#22687)

Preserves string content chunks for non-tool-call requests for convenience.

One thing: Anthropic events look like this:

```
RawContentBlockStartEvent(content_block=TextBlock(text='', type='text'), index=0, type='content_block_start')
RawContentBlockDeltaEvent(delta=TextDelta(text='\nThe', type='text_delta'), index=0, type='content_block_delta')
RawContentBlockDeltaEvent(delta=TextDelta(text=' provide', type='text_delta'), index=0, type='content_block_delta')
...
RawContentBlockStartEvent(content_block=ToolUseBlock(id='toolu_01GJ6x2ddcMG3psDNNe4eDqb', input={}, name='get_weather', type='tool_use'), index=1, type='content_block_start')
RawContentBlockDeltaEvent(delta=InputJsonDelta(partial_json='', type='input_json_delta'), index=1, type='content_block_delta')
```

Note that `delta` has a `type` field. With this implementation, I'm dropping it because `merge_lists` behavior will concatenate strings. We currently treat `index` as a special field when merging lists - would it be worth adding `type` too? If so, what do we set as the `type` of a content block chunk: `text` vs. `text_delta`, and `tool_use` vs. `input_json_delta`?

CC @ccurme @efriis @baskaryan
---
 .../langchain_anthropic/chat_models.py        | 139 ++++++++----
 .../integration_tests/test_chat_models.py     | 102 +++++++++++--
 2 files changed, 156 insertions(+), 85 deletions(-)

diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py
index ce0f3bb98a6..6b3250e847b 100644
--- a/libs/partners/anthropic/langchain_anthropic/chat_models.py
+++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py
@@ -1,4 +1,3 @@
-import json
 import os
 import re
 import warnings
@@ -142,7 +141,7 @@ def _format_messages(messages: List[BaseMessage]) -> Tuple[Optional[str], List[D
         [
             {
                 "role": _message_type_lookups[m.type],
-                "content": [_AnthropicMessageContent(text=m.content).dict()],
+                "content": [_AnthropicMessageContent(text=m.content).model_dump()],
             }
             for m in messages
         ]
@@ -670,34 +669,13 @@ class ChatAnthropic(BaseChatModel):
         if stream_usage is None:
             stream_usage = self.stream_usage
         params = self._format_params(messages=messages, stop=stop, **kwargs)
-        if _tools_in_params(params):
-            result = self._generate(
-                messages, stop=stop, run_manager=run_manager, **kwargs
-            )
-            message = result.generations[0].message
-            if isinstance(message, AIMessage) and message.tool_calls is not None:
-                tool_call_chunks = [
-                    {
-                        "name": tool_call["name"],
-                        "args": json.dumps(tool_call["args"]),
-                        "id": tool_call["id"],
-                        "index": idx,
-                    }
-                    for idx, tool_call in enumerate(message.tool_calls)
-                ]
-                message_chunk = AIMessageChunk(
-                    content=message.content,
-                    tool_call_chunks=tool_call_chunks,  # type: ignore[arg-type]
-                    usage_metadata=message.usage_metadata,
-                )
-                yield ChatGenerationChunk(message=message_chunk)
-            else:
-                yield cast(ChatGenerationChunk, result.generations[0])
-            return
         stream = self._client.messages.create(**params, stream=True)
+        coerce_content_to_string = not _tools_in_params(params)
         for event in stream:
             msg = _make_message_chunk_from_anthropic_event(
-                event, stream_usage=stream_usage
+                event,
+                stream_usage=stream_usage,
+                coerce_content_to_string=coerce_content_to_string,
             )
             if msg is not None:
                 chunk = ChatGenerationChunk(message=msg)
@@ -717,35 +695,13 @@ class ChatAnthropic(BaseChatModel):
         if stream_usage is None:
             stream_usage = self.stream_usage
         params = self._format_params(messages=messages, stop=stop, **kwargs)
-        if _tools_in_params(params):
-            warnings.warn("stream: Tool use is not yet supported in streaming mode.")
-            result = await self._agenerate(
-                messages, stop=stop, run_manager=run_manager, **kwargs
-            )
-            message = result.generations[0].message
-            if isinstance(message, AIMessage) and message.tool_calls is not None:
-                tool_call_chunks = [
-                    {
-                        "name": tool_call["name"],
-                        "args": json.dumps(tool_call["args"]),
-                        "id": tool_call["id"],
-                        "index": idx,
-                    }
-                    for idx, tool_call in enumerate(message.tool_calls)
-                ]
-                message_chunk = AIMessageChunk(
-                    content=message.content,
-                    tool_call_chunks=tool_call_chunks,  # type: ignore[arg-type]
-                    usage_metadata=message.usage_metadata,
-                )
-                yield ChatGenerationChunk(message=message_chunk)
-            else:
-                yield cast(ChatGenerationChunk, result.generations[0])
-            return
         stream = await self._async_client.messages.create(**params, stream=True)
+        coerce_content_to_string = not _tools_in_params(params)
         async for event in stream:
             msg = _make_message_chunk_from_anthropic_event(
-                event, stream_usage=stream_usage
+                event,
+                stream_usage=stream_usage,
+                coerce_content_to_string=coerce_content_to_string,
             )
             if msg is not None:
                 chunk = ChatGenerationChunk(message=msg)
@@ -789,15 +745,10 @@ class ChatAnthropic(BaseChatModel):
     ) -> ChatResult:
         params = self._format_params(messages=messages, stop=stop, **kwargs)
         if self.streaming:
-            if _tools_in_params(params):
-                warnings.warn(
-                    "stream: Tool use is not yet supported in streaming mode."
-                )
-            else:
-                stream_iter = self._stream(
-                    messages, stop=stop, run_manager=run_manager, **kwargs
-                )
-                return generate_from_stream(stream_iter)
+            stream_iter = self._stream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return generate_from_stream(stream_iter)
         data = self._client.messages.create(**params)
         return self._format_output(data, **kwargs)
 
@@ -810,15 +761,10 @@ class ChatAnthropic(BaseChatModel):
     ) -> ChatResult:
         params = self._format_params(messages=messages, stop=stop, **kwargs)
         if self.streaming:
-            if _tools_in_params(params):
-                warnings.warn(
-                    "stream: Tool use is not yet supported in streaming mode."
-                )
-            else:
-                stream_iter = self._astream(
-                    messages, stop=stop, run_manager=run_manager, **kwargs
-                )
-                return await agenerate_from_stream(stream_iter)
+            stream_iter = self._astream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return await agenerate_from_stream(stream_iter)
         data = await self._async_client.messages.create(**params)
         return self._format_output(data, **kwargs)
 
@@ -1117,6 +1063,7 @@ def _make_message_chunk_from_anthropic_event(
     event: anthropic.types.RawMessageStreamEvent,
     *,
     stream_usage: bool = True,
+    coerce_content_to_string: bool,
 ) -> Optional[AIMessageChunk]:
     """Convert Anthropic event to AIMessageChunk.
 
@@ -1124,20 +1071,60 @@ def _make_message_chunk_from_anthropic_event(
     we return None.
     """
     message_chunk: Optional[AIMessageChunk] = None
+    # See https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/lib/streaming/_messages.py  # noqa: E501
     if event.type == "message_start" and stream_usage:
         input_tokens = event.message.usage.input_tokens
         message_chunk = AIMessageChunk(
-            content="",
+            content="" if coerce_content_to_string else [],
             usage_metadata=UsageMetadata(
                 input_tokens=input_tokens,
                 output_tokens=0,
                 total_tokens=input_tokens,
             ),
         )
-    # See https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/lib/streaming/_messages.py  # noqa: E501
-    elif event.type == "content_block_delta" and event.delta.type == "text_delta":
-        text = event.delta.text
-        message_chunk = AIMessageChunk(content=text)
+    elif (
+        event.type == "content_block_start"
+        and event.content_block is not None
+        and event.content_block.type == "tool_use"
+    ):
+        if coerce_content_to_string:
+            warnings.warn("Received unexpected tool content block.")
+        content_block = event.content_block.model_dump()
+        content_block["index"] = event.index
+        tool_call_chunk = {
+            "index": event.index,
+            "id": event.content_block.id,
+            "name": event.content_block.name,
+            "args": "",
+        }
+        message_chunk = AIMessageChunk(
+            content=[content_block],
+            tool_call_chunks=[tool_call_chunk],  # type: ignore
+        )
+    elif event.type == "content_block_delta":
+        if event.delta.type == "text_delta":
+            if coerce_content_to_string:
+                text = event.delta.text
+                message_chunk = AIMessageChunk(content=text)
+            else:
+                content_block = event.delta.model_dump()
+                content_block["index"] = event.index
+                content_block["type"] = "text"
+                message_chunk = AIMessageChunk(content=[content_block])
+        elif event.delta.type == "input_json_delta":
+            content_block = event.delta.model_dump()
+            content_block["index"] = event.index
+            content_block["type"] = "tool_use"
+            tool_call_chunk = {
+                "index": event.index,
+                "id": None,
+                "name": None,
+                "args": event.delta.partial_json,
+            }
+            message_chunk = AIMessageChunk(
+                content=[content_block],
+                tool_call_chunks=[tool_call_chunk],  # type: ignore
+            )
     elif event.type == "message_delta" and stream_usage:
         output_tokens = event.usage.output_tokens
         message_chunk = AIMessageChunk(
diff --git a/libs/partners/anthropic/tests/integration_tests/test_chat_models.py b/libs/partners/anthropic/tests/integration_tests/test_chat_models.py
index 3ccda5ab3e8..198278fed87 100644
--- a/libs/partners/anthropic/tests/integration_tests/test_chat_models.py
+++ b/libs/partners/anthropic/tests/integration_tests/test_chat_models.py
@@ -146,6 +146,55 @@ async def test_abatch_tags() -> None:
         assert isinstance(token.content, str)
 
 
+async def test_async_tool_use() -> None:
+    llm = ChatAnthropic(  # type: ignore[call-arg]
+        model=MODEL_NAME,
+    )
+
+    llm_with_tools = llm.bind_tools(
+        [
+            {
+                "name": "get_weather",
+                "description": "Get weather report for a city",
+                "input_schema": {
+                    "type": "object",
+                    "properties": {"location": {"type": "string"}},
+                },
+            }
+        ]
+    )
+    response = await llm_with_tools.ainvoke("what's the weather in san francisco, ca")
+    assert isinstance(response, AIMessage)
+    assert isinstance(response.content, list)
+    assert isinstance(response.tool_calls, list)
+    assert len(response.tool_calls) == 1
+    tool_call = response.tool_calls[0]
+    assert tool_call["name"] == "get_weather"
+    assert isinstance(tool_call["args"], dict)
+    assert "location" in tool_call["args"]
+
+    # Test streaming
+    first = True
+    chunks = []  # type: ignore
+    async for chunk in llm_with_tools.astream(
+        "what's the weather in san francisco, ca"
+    ):
+        chunks = chunks + [chunk]
+        if first:
+            gathered = chunk
+            first = False
+        else:
+            gathered = gathered + chunk  # type: ignore
+    assert len(chunks) > 1
+    assert isinstance(gathered, AIMessageChunk)
+    assert isinstance(gathered.tool_call_chunks, list)
+    assert len(gathered.tool_call_chunks) == 1
+    tool_call_chunk = gathered.tool_call_chunks[0]
+    assert tool_call_chunk["name"] == "get_weather"
+    assert isinstance(tool_call_chunk["args"], str)
+    assert "location" in json.loads(tool_call_chunk["args"])
+
+
 def test_batch() -> None:
     """Test batch tokens from ChatAnthropicMessages."""
     llm = ChatAnthropicMessages(model_name=MODEL_NAME)  # type: ignore[call-arg, call-arg]
@@ -313,7 +362,7 @@ async def test_astreaming() -> None:
 
 def test_tool_use() -> None:
     llm = ChatAnthropic(  # type: ignore[call-arg]
-        model="claude-3-opus-20240229",
+        model=MODEL_NAME,
     )
 
     llm_with_tools = llm.bind_tools(
@@ -339,20 +388,55 @@ def test_tool_use() -> None:
     assert "location" in tool_call["args"]
 
     # Test streaming
+    input = "how are you? what's the weather in san francisco, ca"
     first = True
-    for chunk in llm_with_tools.stream("what's the weather in san francisco, ca"):
+    chunks = []  # type: ignore
+    for chunk in llm_with_tools.stream(input):
+        chunks = chunks + [chunk]
         if first:
             gathered = chunk
             first = False
         else:
             gathered = gathered + chunk  # type: ignore
+    assert len(chunks) > 1
+    assert isinstance(gathered.content, list)
+    assert len(gathered.content) == 2
+    tool_use_block = None
+    for content_block in gathered.content:
+        assert isinstance(content_block, dict)
+        if content_block["type"] == "tool_use":
+            tool_use_block = content_block
+            break
+    assert tool_use_block is not None
+    assert tool_use_block["name"] == "get_weather"
+    assert "location" in json.loads(tool_use_block["partial_json"])
     assert isinstance(gathered, AIMessageChunk)
-    assert isinstance(gathered.tool_call_chunks, list)
-    assert len(gathered.tool_call_chunks) == 1
-    tool_call_chunk = gathered.tool_call_chunks[0]
-    assert tool_call_chunk["name"] == "get_weather"
-    assert isinstance(tool_call_chunk["args"], str)
-    assert "location" in json.loads(tool_call_chunk["args"])
+    assert isinstance(gathered.tool_calls, list)
+    assert len(gathered.tool_calls) == 1
+    tool_call = gathered.tool_calls[0]
+    assert tool_call["name"] == "get_weather"
+    assert isinstance(tool_call["args"], dict)
+    assert "location" in tool_call["args"]
+    assert tool_call["id"] is not None
+
+    # Test passing response back to model
+    stream = llm_with_tools.stream(
+        [
+            input,
+            gathered,
+            ToolMessage(content="sunny and warm", tool_call_id=tool_call["id"]),
+        ]
+    )
+    chunks = []  # type: ignore
+    first = True
+    for chunk in stream:
+        chunks = chunks + [chunk]
+        if first:
+            gathered = chunk
+            first = False
+        else:
+            gathered = gathered + chunk  # type: ignore
+    assert len(chunks) > 1
 
 
 def test_anthropic_with_empty_text_block() -> None:
@@ -428,7 +512,7 @@ class GetWeather(BaseModel):
 
 @pytest.mark.parametrize("tool_choice", ["GetWeather", "auto", "any"])
 def test_anthropic_bind_tools_tool_choice(tool_choice: str) -> None:
     chat_model = ChatAnthropic(  # type: ignore[call-arg]
-        model="claude-3-sonnet-20240229",
+        model=MODEL_NAME,
     )
     chat_model_with_tools = chat_model.bind_tools([GetWeather], tool_choice=tool_choice)
     response = chat_model_with_tools.invoke("what's the weather in ny and la")