mirror of
				https://github.com/hwchase17/langchain.git
				synced 2025-10-25 12:44:04 +00:00 
			
		
		
		
	anthropic[minor]: Adds streaming tool call support for Anthropic (#22687)
Preserves string content chunks for non tool call requests for
convenience.
One thing - Anthropic events look like this:
```
RawContentBlockStartEvent(content_block=TextBlock(text='', type='text'), index=0, type='content_block_start')
RawContentBlockDeltaEvent(delta=TextDelta(text='<thinking>\nThe', type='text_delta'), index=0, type='content_block_delta')
RawContentBlockDeltaEvent(delta=TextDelta(text=' provide', type='text_delta'), index=0, type='content_block_delta')
...
RawContentBlockStartEvent(content_block=ToolUseBlock(id='toolu_01GJ6x2ddcMG3psDNNe4eDqb', input={}, name='get_weather', type='tool_use'), index=1, type='content_block_start')
RawContentBlockDeltaEvent(delta=InputJsonDelta(partial_json='', type='input_json_delta'), index=1, type='content_block_delta')
```
Note that `delta` has a `type` field. With this implementation, I'm
dropping it because `merge_list` behavior will concatenate strings.
We currently have `index` as a special field when merging lists, would
it be worth adding `type` too?
If so, what do we set as a context block chunk? `text` vs.
`text_delta`/`tool_use` vs `input_json_delta`?
CC @ccurme @efriis @baskaryan
			
			
This commit is contained in:
		| @@ -1,4 +1,3 @@ | |||||||
| import json |  | ||||||
| import os | import os | ||||||
| import re | import re | ||||||
| import warnings | import warnings | ||||||
| @@ -142,7 +141,7 @@ def _format_messages(messages: List[BaseMessage]) -> Tuple[Optional[str], List[D | |||||||
|     [ |     [ | ||||||
|                 { |                 { | ||||||
|                     "role": _message_type_lookups[m.type], |                     "role": _message_type_lookups[m.type], | ||||||
|                     "content": [_AnthropicMessageContent(text=m.content).dict()], |                     "content": [_AnthropicMessageContent(text=m.content).model_dump()], | ||||||
|                 } |                 } | ||||||
|                 for m in messages |                 for m in messages | ||||||
|             ] |             ] | ||||||
| @@ -670,34 +669,13 @@ class ChatAnthropic(BaseChatModel): | |||||||
|         if stream_usage is None: |         if stream_usage is None: | ||||||
|             stream_usage = self.stream_usage |             stream_usage = self.stream_usage | ||||||
|         params = self._format_params(messages=messages, stop=stop, **kwargs) |         params = self._format_params(messages=messages, stop=stop, **kwargs) | ||||||
|         if _tools_in_params(params): |  | ||||||
|             result = self._generate( |  | ||||||
|                 messages, stop=stop, run_manager=run_manager, **kwargs |  | ||||||
|             ) |  | ||||||
|             message = result.generations[0].message |  | ||||||
|             if isinstance(message, AIMessage) and message.tool_calls is not None: |  | ||||||
|                 tool_call_chunks = [ |  | ||||||
|                     { |  | ||||||
|                         "name": tool_call["name"], |  | ||||||
|                         "args": json.dumps(tool_call["args"]), |  | ||||||
|                         "id": tool_call["id"], |  | ||||||
|                         "index": idx, |  | ||||||
|                     } |  | ||||||
|                     for idx, tool_call in enumerate(message.tool_calls) |  | ||||||
|                 ] |  | ||||||
|                 message_chunk = AIMessageChunk( |  | ||||||
|                     content=message.content, |  | ||||||
|                     tool_call_chunks=tool_call_chunks,  # type: ignore[arg-type] |  | ||||||
|                     usage_metadata=message.usage_metadata, |  | ||||||
|                 ) |  | ||||||
|                 yield ChatGenerationChunk(message=message_chunk) |  | ||||||
|             else: |  | ||||||
|                 yield cast(ChatGenerationChunk, result.generations[0]) |  | ||||||
|             return |  | ||||||
|         stream = self._client.messages.create(**params, stream=True) |         stream = self._client.messages.create(**params, stream=True) | ||||||
|  |         coerce_content_to_string = not _tools_in_params(params) | ||||||
|         for event in stream: |         for event in stream: | ||||||
|             msg = _make_message_chunk_from_anthropic_event( |             msg = _make_message_chunk_from_anthropic_event( | ||||||
|                 event, stream_usage=stream_usage |                 event, | ||||||
|  |                 stream_usage=stream_usage, | ||||||
|  |                 coerce_content_to_string=coerce_content_to_string, | ||||||
|             ) |             ) | ||||||
|             if msg is not None: |             if msg is not None: | ||||||
|                 chunk = ChatGenerationChunk(message=msg) |                 chunk = ChatGenerationChunk(message=msg) | ||||||
| @@ -717,35 +695,13 @@ class ChatAnthropic(BaseChatModel): | |||||||
|         if stream_usage is None: |         if stream_usage is None: | ||||||
|             stream_usage = self.stream_usage |             stream_usage = self.stream_usage | ||||||
|         params = self._format_params(messages=messages, stop=stop, **kwargs) |         params = self._format_params(messages=messages, stop=stop, **kwargs) | ||||||
|         if _tools_in_params(params): |  | ||||||
|             warnings.warn("stream: Tool use is not yet supported in streaming mode.") |  | ||||||
|             result = await self._agenerate( |  | ||||||
|                 messages, stop=stop, run_manager=run_manager, **kwargs |  | ||||||
|             ) |  | ||||||
|             message = result.generations[0].message |  | ||||||
|             if isinstance(message, AIMessage) and message.tool_calls is not None: |  | ||||||
|                 tool_call_chunks = [ |  | ||||||
|                     { |  | ||||||
|                         "name": tool_call["name"], |  | ||||||
|                         "args": json.dumps(tool_call["args"]), |  | ||||||
|                         "id": tool_call["id"], |  | ||||||
|                         "index": idx, |  | ||||||
|                     } |  | ||||||
|                     for idx, tool_call in enumerate(message.tool_calls) |  | ||||||
|                 ] |  | ||||||
|                 message_chunk = AIMessageChunk( |  | ||||||
|                     content=message.content, |  | ||||||
|                     tool_call_chunks=tool_call_chunks,  # type: ignore[arg-type] |  | ||||||
|                     usage_metadata=message.usage_metadata, |  | ||||||
|                 ) |  | ||||||
|                 yield ChatGenerationChunk(message=message_chunk) |  | ||||||
|             else: |  | ||||||
|                 yield cast(ChatGenerationChunk, result.generations[0]) |  | ||||||
|             return |  | ||||||
|         stream = await self._async_client.messages.create(**params, stream=True) |         stream = await self._async_client.messages.create(**params, stream=True) | ||||||
|  |         coerce_content_to_string = not _tools_in_params(params) | ||||||
|         async for event in stream: |         async for event in stream: | ||||||
|             msg = _make_message_chunk_from_anthropic_event( |             msg = _make_message_chunk_from_anthropic_event( | ||||||
|                 event, stream_usage=stream_usage |                 event, | ||||||
|  |                 stream_usage=stream_usage, | ||||||
|  |                 coerce_content_to_string=coerce_content_to_string, | ||||||
|             ) |             ) | ||||||
|             if msg is not None: |             if msg is not None: | ||||||
|                 chunk = ChatGenerationChunk(message=msg) |                 chunk = ChatGenerationChunk(message=msg) | ||||||
| @@ -789,11 +745,6 @@ class ChatAnthropic(BaseChatModel): | |||||||
|     ) -> ChatResult: |     ) -> ChatResult: | ||||||
|         params = self._format_params(messages=messages, stop=stop, **kwargs) |         params = self._format_params(messages=messages, stop=stop, **kwargs) | ||||||
|         if self.streaming: |         if self.streaming: | ||||||
|             if _tools_in_params(params): |  | ||||||
|                 warnings.warn( |  | ||||||
|                     "stream: Tool use is not yet supported in streaming mode." |  | ||||||
|                 ) |  | ||||||
|             else: |  | ||||||
|             stream_iter = self._stream( |             stream_iter = self._stream( | ||||||
|                 messages, stop=stop, run_manager=run_manager, **kwargs |                 messages, stop=stop, run_manager=run_manager, **kwargs | ||||||
|             ) |             ) | ||||||
| @@ -810,11 +761,6 @@ class ChatAnthropic(BaseChatModel): | |||||||
|     ) -> ChatResult: |     ) -> ChatResult: | ||||||
|         params = self._format_params(messages=messages, stop=stop, **kwargs) |         params = self._format_params(messages=messages, stop=stop, **kwargs) | ||||||
|         if self.streaming: |         if self.streaming: | ||||||
|             if _tools_in_params(params): |  | ||||||
|                 warnings.warn( |  | ||||||
|                     "stream: Tool use is not yet supported in streaming mode." |  | ||||||
|                 ) |  | ||||||
|             else: |  | ||||||
|             stream_iter = self._astream( |             stream_iter = self._astream( | ||||||
|                 messages, stop=stop, run_manager=run_manager, **kwargs |                 messages, stop=stop, run_manager=run_manager, **kwargs | ||||||
|             ) |             ) | ||||||
| @@ -1117,6 +1063,7 @@ def _make_message_chunk_from_anthropic_event( | |||||||
|     event: anthropic.types.RawMessageStreamEvent, |     event: anthropic.types.RawMessageStreamEvent, | ||||||
|     *, |     *, | ||||||
|     stream_usage: bool = True, |     stream_usage: bool = True, | ||||||
|  |     coerce_content_to_string: bool, | ||||||
| ) -> Optional[AIMessageChunk]: | ) -> Optional[AIMessageChunk]: | ||||||
|     """Convert Anthropic event to AIMessageChunk. |     """Convert Anthropic event to AIMessageChunk. | ||||||
|  |  | ||||||
| @@ -1124,20 +1071,60 @@ def _make_message_chunk_from_anthropic_event( | |||||||
|     we return None. |     we return None. | ||||||
|     """ |     """ | ||||||
|     message_chunk: Optional[AIMessageChunk] = None |     message_chunk: Optional[AIMessageChunk] = None | ||||||
|  |     # See https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/lib/streaming/_messages.py  # noqa: E501 | ||||||
|     if event.type == "message_start" and stream_usage: |     if event.type == "message_start" and stream_usage: | ||||||
|         input_tokens = event.message.usage.input_tokens |         input_tokens = event.message.usage.input_tokens | ||||||
|         message_chunk = AIMessageChunk( |         message_chunk = AIMessageChunk( | ||||||
|             content="", |             content="" if coerce_content_to_string else [], | ||||||
|             usage_metadata=UsageMetadata( |             usage_metadata=UsageMetadata( | ||||||
|                 input_tokens=input_tokens, |                 input_tokens=input_tokens, | ||||||
|                 output_tokens=0, |                 output_tokens=0, | ||||||
|                 total_tokens=input_tokens, |                 total_tokens=input_tokens, | ||||||
|             ), |             ), | ||||||
|         ) |         ) | ||||||
|     # See https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/lib/streaming/_messages.py  # noqa: E501 |     elif ( | ||||||
|     elif event.type == "content_block_delta" and event.delta.type == "text_delta": |         event.type == "content_block_start" | ||||||
|  |         and event.content_block is not None | ||||||
|  |         and event.content_block.type == "tool_use" | ||||||
|  |     ): | ||||||
|  |         if coerce_content_to_string: | ||||||
|  |             warnings.warn("Received unexpected tool content block.") | ||||||
|  |         content_block = event.content_block.model_dump() | ||||||
|  |         content_block["index"] = event.index | ||||||
|  |         tool_call_chunk = { | ||||||
|  |             "index": event.index, | ||||||
|  |             "id": event.content_block.id, | ||||||
|  |             "name": event.content_block.name, | ||||||
|  |             "args": "", | ||||||
|  |         } | ||||||
|  |         message_chunk = AIMessageChunk( | ||||||
|  |             content=[content_block], | ||||||
|  |             tool_call_chunks=[tool_call_chunk],  # type: ignore | ||||||
|  |         ) | ||||||
|  |     elif event.type == "content_block_delta": | ||||||
|  |         if event.delta.type == "text_delta": | ||||||
|  |             if coerce_content_to_string: | ||||||
|                 text = event.delta.text |                 text = event.delta.text | ||||||
|                 message_chunk = AIMessageChunk(content=text) |                 message_chunk = AIMessageChunk(content=text) | ||||||
|  |             else: | ||||||
|  |                 content_block = event.delta.model_dump() | ||||||
|  |                 content_block["index"] = event.index | ||||||
|  |                 content_block["type"] = "text" | ||||||
|  |                 message_chunk = AIMessageChunk(content=[content_block]) | ||||||
|  |         elif event.delta.type == "input_json_delta": | ||||||
|  |             content_block = event.delta.model_dump() | ||||||
|  |             content_block["index"] = event.index | ||||||
|  |             content_block["type"] = "tool_use" | ||||||
|  |             tool_call_chunk = { | ||||||
|  |                 "index": event.index, | ||||||
|  |                 "id": None, | ||||||
|  |                 "name": None, | ||||||
|  |                 "args": event.delta.partial_json, | ||||||
|  |             } | ||||||
|  |             message_chunk = AIMessageChunk( | ||||||
|  |                 content=[content_block], | ||||||
|  |                 tool_call_chunks=[tool_call_chunk],  # type: ignore | ||||||
|  |             ) | ||||||
|     elif event.type == "message_delta" and stream_usage: |     elif event.type == "message_delta" and stream_usage: | ||||||
|         output_tokens = event.usage.output_tokens |         output_tokens = event.usage.output_tokens | ||||||
|         message_chunk = AIMessageChunk( |         message_chunk = AIMessageChunk( | ||||||
|   | |||||||
| @@ -146,6 +146,55 @@ async def test_abatch_tags() -> None: | |||||||
|         assert isinstance(token.content, str) |         assert isinstance(token.content, str) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | async def test_async_tool_use() -> None: | ||||||
|  |     llm = ChatAnthropic(  # type: ignore[call-arg] | ||||||
|  |         model=MODEL_NAME, | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     llm_with_tools = llm.bind_tools( | ||||||
|  |         [ | ||||||
|  |             { | ||||||
|  |                 "name": "get_weather", | ||||||
|  |                 "description": "Get weather report for a city", | ||||||
|  |                 "input_schema": { | ||||||
|  |                     "type": "object", | ||||||
|  |                     "properties": {"location": {"type": "string"}}, | ||||||
|  |                 }, | ||||||
|  |             } | ||||||
|  |         ] | ||||||
|  |     ) | ||||||
|  |     response = await llm_with_tools.ainvoke("what's the weather in san francisco, ca") | ||||||
|  |     assert isinstance(response, AIMessage) | ||||||
|  |     assert isinstance(response.content, list) | ||||||
|  |     assert isinstance(response.tool_calls, list) | ||||||
|  |     assert len(response.tool_calls) == 1 | ||||||
|  |     tool_call = response.tool_calls[0] | ||||||
|  |     assert tool_call["name"] == "get_weather" | ||||||
|  |     assert isinstance(tool_call["args"], dict) | ||||||
|  |     assert "location" in tool_call["args"] | ||||||
|  |  | ||||||
|  |     # Test streaming | ||||||
|  |     first = True | ||||||
|  |     chunks = []  # type: ignore | ||||||
|  |     async for chunk in llm_with_tools.astream( | ||||||
|  |         "what's the weather in san francisco, ca" | ||||||
|  |     ): | ||||||
|  |         chunks = chunks + [chunk] | ||||||
|  |         if first: | ||||||
|  |             gathered = chunk | ||||||
|  |             first = False | ||||||
|  |         else: | ||||||
|  |             gathered = gathered + chunk  # type: ignore | ||||||
|  |     assert len(chunks) > 1 | ||||||
|  |     assert isinstance(gathered, AIMessageChunk) | ||||||
|  |     assert isinstance(gathered.tool_call_chunks, list) | ||||||
|  |     assert len(gathered.tool_call_chunks) == 1 | ||||||
|  |     tool_call_chunk = gathered.tool_call_chunks[0] | ||||||
|  |     assert tool_call_chunk["name"] == "get_weather" | ||||||
|  |     assert isinstance(tool_call_chunk["args"], str) | ||||||
|  |     assert "location" in json.loads(tool_call_chunk["args"]) | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_batch() -> None: | def test_batch() -> None: | ||||||
|     """Test batch tokens from ChatAnthropicMessages.""" |     """Test batch tokens from ChatAnthropicMessages.""" | ||||||
|     llm = ChatAnthropicMessages(model_name=MODEL_NAME)  # type: ignore[call-arg, call-arg] |     llm = ChatAnthropicMessages(model_name=MODEL_NAME)  # type: ignore[call-arg, call-arg] | ||||||
| @@ -313,7 +362,7 @@ async def test_astreaming() -> None: | |||||||
|  |  | ||||||
| def test_tool_use() -> None: | def test_tool_use() -> None: | ||||||
|     llm = ChatAnthropic(  # type: ignore[call-arg] |     llm = ChatAnthropic(  # type: ignore[call-arg] | ||||||
|         model="claude-3-opus-20240229", |         model=MODEL_NAME, | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     llm_with_tools = llm.bind_tools( |     llm_with_tools = llm.bind_tools( | ||||||
| @@ -339,20 +388,55 @@ def test_tool_use() -> None: | |||||||
|     assert "location" in tool_call["args"] |     assert "location" in tool_call["args"] | ||||||
|  |  | ||||||
|     # Test streaming |     # Test streaming | ||||||
|  |     input = "how are you? what's the weather in san francisco, ca" | ||||||
|     first = True |     first = True | ||||||
|     for chunk in llm_with_tools.stream("what's the weather in san francisco, ca"): |     chunks = []  # type: ignore | ||||||
|  |     for chunk in llm_with_tools.stream(input): | ||||||
|  |         chunks = chunks + [chunk] | ||||||
|         if first: |         if first: | ||||||
|             gathered = chunk |             gathered = chunk | ||||||
|             first = False |             first = False | ||||||
|         else: |         else: | ||||||
|             gathered = gathered + chunk  # type: ignore |             gathered = gathered + chunk  # type: ignore | ||||||
|  |     assert len(chunks) > 1 | ||||||
|  |     assert isinstance(gathered.content, list) | ||||||
|  |     assert len(gathered.content) == 2 | ||||||
|  |     tool_use_block = None | ||||||
|  |     for content_block in gathered.content: | ||||||
|  |         assert isinstance(content_block, dict) | ||||||
|  |         if content_block["type"] == "tool_use": | ||||||
|  |             tool_use_block = content_block | ||||||
|  |             break | ||||||
|  |     assert tool_use_block is not None | ||||||
|  |     assert tool_use_block["name"] == "get_weather" | ||||||
|  |     assert "location" in json.loads(tool_use_block["partial_json"]) | ||||||
|     assert isinstance(gathered, AIMessageChunk) |     assert isinstance(gathered, AIMessageChunk) | ||||||
|     assert isinstance(gathered.tool_call_chunks, list) |     assert isinstance(gathered.tool_calls, list) | ||||||
|     assert len(gathered.tool_call_chunks) == 1 |     assert len(gathered.tool_calls) == 1 | ||||||
|     tool_call_chunk = gathered.tool_call_chunks[0] |     tool_call = gathered.tool_calls[0] | ||||||
|     assert tool_call_chunk["name"] == "get_weather" |     assert tool_call["name"] == "get_weather" | ||||||
|     assert isinstance(tool_call_chunk["args"], str) |     assert isinstance(tool_call["args"], dict) | ||||||
|     assert "location" in json.loads(tool_call_chunk["args"]) |     assert "location" in tool_call["args"] | ||||||
|  |     assert tool_call["id"] is not None | ||||||
|  |  | ||||||
|  |     # Test passing response back to model | ||||||
|  |     stream = llm_with_tools.stream( | ||||||
|  |         [ | ||||||
|  |             input, | ||||||
|  |             gathered, | ||||||
|  |             ToolMessage(content="sunny and warm", tool_call_id=tool_call["id"]), | ||||||
|  |         ] | ||||||
|  |     ) | ||||||
|  |     chunks = []  # type: ignore | ||||||
|  |     first = True | ||||||
|  |     for chunk in stream: | ||||||
|  |         chunks = chunks + [chunk] | ||||||
|  |         if first: | ||||||
|  |             gathered = chunk | ||||||
|  |             first = False | ||||||
|  |         else: | ||||||
|  |             gathered = gathered + chunk  # type: ignore | ||||||
|  |     assert len(chunks) > 1 | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_anthropic_with_empty_text_block() -> None: | def test_anthropic_with_empty_text_block() -> None: | ||||||
| @@ -428,7 +512,7 @@ class GetWeather(BaseModel): | |||||||
| @pytest.mark.parametrize("tool_choice", ["GetWeather", "auto", "any"]) | @pytest.mark.parametrize("tool_choice", ["GetWeather", "auto", "any"]) | ||||||
| def test_anthropic_bind_tools_tool_choice(tool_choice: str) -> None: | def test_anthropic_bind_tools_tool_choice(tool_choice: str) -> None: | ||||||
|     chat_model = ChatAnthropic(  # type: ignore[call-arg] |     chat_model = ChatAnthropic(  # type: ignore[call-arg] | ||||||
|         model="claude-3-sonnet-20240229", |         model=MODEL_NAME, | ||||||
|     ) |     ) | ||||||
|     chat_model_with_tools = chat_model.bind_tools([GetWeather], tool_choice=tool_choice) |     chat_model_with_tools = chat_model.bind_tools([GetWeather], tool_choice=tool_choice) | ||||||
|     response = chat_model_with_tools.invoke("what's the weather in ny and la") |     response = chat_model_with_tools.invoke("what's the weather in ny and la") | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user