anthropic[minor]: Adds streaming tool call support for Anthropic (#22687)

Preserves string content chunks for non-tool-call requests for
convenience.

One thing - Anthropic events look like this:

```
RawContentBlockStartEvent(content_block=TextBlock(text='', type='text'), index=0, type='content_block_start')
RawContentBlockDeltaEvent(delta=TextDelta(text='<thinking>\nThe', type='text_delta'), index=0, type='content_block_delta')
RawContentBlockDeltaEvent(delta=TextDelta(text=' provide', type='text_delta'), index=0, type='content_block_delta')
...
RawContentBlockStartEvent(content_block=ToolUseBlock(id='toolu_01GJ6x2ddcMG3psDNNe4eDqb', input={}, name='get_weather', type='tool_use'), index=1, type='content_block_start')
RawContentBlockDeltaEvent(delta=InputJsonDelta(partial_json='', type='input_json_delta'), index=1, type='content_block_delta')
```

Note that `delta` has a `type` field. With this implementation, I'm
dropping it because `merge_list` behavior will concatenate strings.

We currently have `index` as a special field when merging lists, would
it be worth adding `type` too?

If so, what do we set as a content block chunk's type? `text` vs.
`text_delta`, and `tool_use` vs. `input_json_delta`?

CC @ccurme @efriis @baskaryan
This commit is contained in:
Jacob Lee
2024-06-14 09:14:43 -07:00
committed by GitHub
parent f40b2c6f9d
commit 181a61982f
2 changed files with 156 additions and 85 deletions

View File

@@ -1,4 +1,3 @@
import json
import os
import re
import warnings
@@ -142,7 +141,7 @@ def _format_messages(messages: List[BaseMessage]) -> Tuple[Optional[str], List[D
[
{
"role": _message_type_lookups[m.type],
"content": [_AnthropicMessageContent(text=m.content).dict()],
"content": [_AnthropicMessageContent(text=m.content).model_dump()],
}
for m in messages
]
@@ -670,34 +669,13 @@ class ChatAnthropic(BaseChatModel):
if stream_usage is None:
stream_usage = self.stream_usage
params = self._format_params(messages=messages, stop=stop, **kwargs)
if _tools_in_params(params):
result = self._generate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
message = result.generations[0].message
if isinstance(message, AIMessage) and message.tool_calls is not None:
tool_call_chunks = [
{
"name": tool_call["name"],
"args": json.dumps(tool_call["args"]),
"id": tool_call["id"],
"index": idx,
}
for idx, tool_call in enumerate(message.tool_calls)
]
message_chunk = AIMessageChunk(
content=message.content,
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
usage_metadata=message.usage_metadata,
)
yield ChatGenerationChunk(message=message_chunk)
else:
yield cast(ChatGenerationChunk, result.generations[0])
return
stream = self._client.messages.create(**params, stream=True)
coerce_content_to_string = not _tools_in_params(params)
for event in stream:
msg = _make_message_chunk_from_anthropic_event(
event, stream_usage=stream_usage
event,
stream_usage=stream_usage,
coerce_content_to_string=coerce_content_to_string,
)
if msg is not None:
chunk = ChatGenerationChunk(message=msg)
@@ -717,35 +695,13 @@ class ChatAnthropic(BaseChatModel):
if stream_usage is None:
stream_usage = self.stream_usage
params = self._format_params(messages=messages, stop=stop, **kwargs)
if _tools_in_params(params):
warnings.warn("stream: Tool use is not yet supported in streaming mode.")
result = await self._agenerate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
message = result.generations[0].message
if isinstance(message, AIMessage) and message.tool_calls is not None:
tool_call_chunks = [
{
"name": tool_call["name"],
"args": json.dumps(tool_call["args"]),
"id": tool_call["id"],
"index": idx,
}
for idx, tool_call in enumerate(message.tool_calls)
]
message_chunk = AIMessageChunk(
content=message.content,
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
usage_metadata=message.usage_metadata,
)
yield ChatGenerationChunk(message=message_chunk)
else:
yield cast(ChatGenerationChunk, result.generations[0])
return
stream = await self._async_client.messages.create(**params, stream=True)
coerce_content_to_string = not _tools_in_params(params)
async for event in stream:
msg = _make_message_chunk_from_anthropic_event(
event, stream_usage=stream_usage
event,
stream_usage=stream_usage,
coerce_content_to_string=coerce_content_to_string,
)
if msg is not None:
chunk = ChatGenerationChunk(message=msg)
@@ -789,15 +745,10 @@ class ChatAnthropic(BaseChatModel):
) -> ChatResult:
params = self._format_params(messages=messages, stop=stop, **kwargs)
if self.streaming:
if _tools_in_params(params):
warnings.warn(
"stream: Tool use is not yet supported in streaming mode."
)
else:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
data = self._client.messages.create(**params)
return self._format_output(data, **kwargs)
@@ -810,15 +761,10 @@ class ChatAnthropic(BaseChatModel):
) -> ChatResult:
params = self._format_params(messages=messages, stop=stop, **kwargs)
if self.streaming:
if _tools_in_params(params):
warnings.warn(
"stream: Tool use is not yet supported in streaming mode."
)
else:
stream_iter = self._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)
stream_iter = self._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)
data = await self._async_client.messages.create(**params)
return self._format_output(data, **kwargs)
@@ -1117,6 +1063,7 @@ def _make_message_chunk_from_anthropic_event(
event: anthropic.types.RawMessageStreamEvent,
*,
stream_usage: bool = True,
coerce_content_to_string: bool,
) -> Optional[AIMessageChunk]:
"""Convert Anthropic event to AIMessageChunk.
@@ -1124,20 +1071,60 @@ def _make_message_chunk_from_anthropic_event(
we return None.
"""
message_chunk: Optional[AIMessageChunk] = None
# See https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/lib/streaming/_messages.py # noqa: E501
if event.type == "message_start" and stream_usage:
input_tokens = event.message.usage.input_tokens
message_chunk = AIMessageChunk(
content="",
content="" if coerce_content_to_string else [],
usage_metadata=UsageMetadata(
input_tokens=input_tokens,
output_tokens=0,
total_tokens=input_tokens,
),
)
# See https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/lib/streaming/_messages.py # noqa: E501
elif event.type == "content_block_delta" and event.delta.type == "text_delta":
text = event.delta.text
message_chunk = AIMessageChunk(content=text)
elif (
event.type == "content_block_start"
and event.content_block is not None
and event.content_block.type == "tool_use"
):
if coerce_content_to_string:
warnings.warn("Received unexpected tool content block.")
content_block = event.content_block.model_dump()
content_block["index"] = event.index
tool_call_chunk = {
"index": event.index,
"id": event.content_block.id,
"name": event.content_block.name,
"args": "",
}
message_chunk = AIMessageChunk(
content=[content_block],
tool_call_chunks=[tool_call_chunk], # type: ignore
)
elif event.type == "content_block_delta":
if event.delta.type == "text_delta":
if coerce_content_to_string:
text = event.delta.text
message_chunk = AIMessageChunk(content=text)
else:
content_block = event.delta.model_dump()
content_block["index"] = event.index
content_block["type"] = "text"
message_chunk = AIMessageChunk(content=[content_block])
elif event.delta.type == "input_json_delta":
content_block = event.delta.model_dump()
content_block["index"] = event.index
content_block["type"] = "tool_use"
tool_call_chunk = {
"index": event.index,
"id": None,
"name": None,
"args": event.delta.partial_json,
}
message_chunk = AIMessageChunk(
content=[content_block],
tool_call_chunks=[tool_call_chunk], # type: ignore
)
elif event.type == "message_delta" and stream_usage:
output_tokens = event.usage.output_tokens
message_chunk = AIMessageChunk(