Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-07 15:36:30 +00:00)
anthropic[minor]: Adds streaming tool call support for Anthropic (#22687)
Preserves string content chunks for non-tool-call requests, for convenience.

One open question. Anthropic events look like this:

```
RawContentBlockStartEvent(content_block=TextBlock(text='', type='text'), index=0, type='content_block_start')
RawContentBlockDeltaEvent(delta=TextDelta(text='<thinking>\nThe', type='text_delta'), index=0, type='content_block_delta')
RawContentBlockDeltaEvent(delta=TextDelta(text=' provide', type='text_delta'), index=0, type='content_block_delta')
...
RawContentBlockStartEvent(content_block=ToolUseBlock(id='toolu_01GJ6x2ddcMG3psDNNe4eDqb', input={}, name='get_weather', type='tool_use'), index=1, type='content_block_start')
RawContentBlockDeltaEvent(delta=InputJsonDelta(partial_json='', type='input_json_delta'), index=1, type='content_block_delta')
```

Note that `delta` has a `type` field. This implementation drops it, because `merge_list` behavior concatenates strings. We currently treat `index` as a special field when merging lists; would it be worth adding `type` too? If so, what do we set on a content block chunk: `text` vs. `text_delta`, and `tool_use` vs. `input_json_delta`?

CC @ccurme @efriis @baskaryan
This commit is contained in: parent f40b2c6f9d, commit 181a61982f.
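For context, a minimal sketch of what this change enables from the caller's side. This is not part of the commit; the model name is illustrative and the tool schema is borrowed from the integration tests below:

```python
import json

from langchain_anthropic import ChatAnthropic

llm = ChatAnthropic(model="claude-3-sonnet-20240229")  # illustrative model name
llm_with_tools = llm.bind_tools(
    [
        {
            "name": "get_weather",
            "description": "Get weather report for a city",
            "input_schema": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
            },
        }
    ]
)

gathered = None
for chunk in llm_with_tools.stream("what's the weather in san francisco, ca"):
    # Chunk addition merges list content blocks by their "index" key and
    # concatenates the streamed partial-JSON strings in tool_call_chunks.
    gathered = chunk if gathered is None else gathered + chunk

tool_call_chunk = gathered.tool_call_chunks[0]
print(tool_call_chunk["name"], json.loads(tool_call_chunk["args"]))
```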
langchain_anthropic/chat_models.py:

```diff
@@ -1,4 +1,3 @@
-import json
 import os
 import re
 import warnings
```
```diff
@@ -142,7 +141,7 @@ def _format_messages(messages: List[BaseMessage]) -> Tuple[Optional[str], List[Dict]]:
     [
         {
             "role": _message_type_lookups[m.type],
-            "content": [_AnthropicMessageContent(text=m.content).dict()],
+            "content": [_AnthropicMessageContent(text=m.content).model_dump()],
         }
         for m in messages
     ]
```
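The `.dict()` to `.model_dump()` switch above tracks Pydantic v2, which deprecates `BaseModel.dict()` in favor of `model_dump()`. A tiny stand-alone illustration; the `_Content` model here is a hypothetical stand-in, not the real `_AnthropicMessageContent`:

```python
from pydantic import BaseModel


class _Content(BaseModel):  # stand-in for _AnthropicMessageContent
    type: str = "text"
    text: str


# model_dump() is the v2 spelling of the old .dict()
block = _Content(text="hi").model_dump()  # {"type": "text", "text": "hi"}
```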
```diff
@@ -670,34 +669,13 @@ class ChatAnthropic(BaseChatModel):
         if stream_usage is None:
             stream_usage = self.stream_usage
         params = self._format_params(messages=messages, stop=stop, **kwargs)
-        if _tools_in_params(params):
-            result = self._generate(
-                messages, stop=stop, run_manager=run_manager, **kwargs
-            )
-            message = result.generations[0].message
-            if isinstance(message, AIMessage) and message.tool_calls is not None:
-                tool_call_chunks = [
-                    {
-                        "name": tool_call["name"],
-                        "args": json.dumps(tool_call["args"]),
-                        "id": tool_call["id"],
-                        "index": idx,
-                    }
-                    for idx, tool_call in enumerate(message.tool_calls)
-                ]
-                message_chunk = AIMessageChunk(
-                    content=message.content,
-                    tool_call_chunks=tool_call_chunks,  # type: ignore[arg-type]
-                    usage_metadata=message.usage_metadata,
-                )
-                yield ChatGenerationChunk(message=message_chunk)
-            else:
-                yield cast(ChatGenerationChunk, result.generations[0])
-            return
         stream = self._client.messages.create(**params, stream=True)
+        coerce_content_to_string = not _tools_in_params(params)
         for event in stream:
             msg = _make_message_chunk_from_anthropic_event(
-                event, stream_usage=stream_usage
+                event,
+                stream_usage=stream_usage,
+                coerce_content_to_string=coerce_content_to_string,
             )
             if msg is not None:
                 chunk = ChatGenerationChunk(message=msg)
```
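The new `coerce_content_to_string` flag preserves the old behavior for plain requests: with no tools in the params, text deltas still arrive as string chunks, while with tools bound every delta becomes a list-of-blocks chunk that keeps the Anthropic block `index`. Roughly, as a sketch with illustrative values:

```python
from langchain_core.messages import AIMessageChunk

# No tools bound (coerce_content_to_string=True): content stays a string.
text_chunk = AIMessageChunk(content="The weather")

# Tools bound (coerce_content_to_string=False): content is a list of blocks
# carrying the Anthropic block index, so later chunks can merge into them.
block_chunk = AIMessageChunk(
    content=[{"type": "text", "text": "The weather", "index": 0}]
)
```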
```diff
@@ -717,35 +695,13 @@ class ChatAnthropic(BaseChatModel):
         if stream_usage is None:
             stream_usage = self.stream_usage
         params = self._format_params(messages=messages, stop=stop, **kwargs)
-        if _tools_in_params(params):
-            warnings.warn("stream: Tool use is not yet supported in streaming mode.")
-            result = await self._agenerate(
-                messages, stop=stop, run_manager=run_manager, **kwargs
-            )
-            message = result.generations[0].message
-            if isinstance(message, AIMessage) and message.tool_calls is not None:
-                tool_call_chunks = [
-                    {
-                        "name": tool_call["name"],
-                        "args": json.dumps(tool_call["args"]),
-                        "id": tool_call["id"],
-                        "index": idx,
-                    }
-                    for idx, tool_call in enumerate(message.tool_calls)
-                ]
-                message_chunk = AIMessageChunk(
-                    content=message.content,
-                    tool_call_chunks=tool_call_chunks,  # type: ignore[arg-type]
-                    usage_metadata=message.usage_metadata,
-                )
-                yield ChatGenerationChunk(message=message_chunk)
-            else:
-                yield cast(ChatGenerationChunk, result.generations[0])
-            return
         stream = await self._async_client.messages.create(**params, stream=True)
+        coerce_content_to_string = not _tools_in_params(params)
         async for event in stream:
             msg = _make_message_chunk_from_anthropic_event(
-                event, stream_usage=stream_usage
+                event,
+                stream_usage=stream_usage,
+                coerce_content_to_string=coerce_content_to_string,
             )
             if msg is not None:
                 chunk = ChatGenerationChunk(message=msg)
```
```diff
@@ -789,11 +745,6 @@ class ChatAnthropic(BaseChatModel):
     ) -> ChatResult:
         params = self._format_params(messages=messages, stop=stop, **kwargs)
         if self.streaming:
-            if _tools_in_params(params):
-                warnings.warn(
-                    "stream: Tool use is not yet supported in streaming mode."
-                )
-            else:
-                stream_iter = self._stream(
-                    messages, stop=stop, run_manager=run_manager, **kwargs
-                )
+            stream_iter = self._stream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
```
```diff
@@ -810,11 +761,6 @@ class ChatAnthropic(BaseChatModel):
     ) -> ChatResult:
         params = self._format_params(messages=messages, stop=stop, **kwargs)
         if self.streaming:
-            if _tools_in_params(params):
-                warnings.warn(
-                    "stream: Tool use is not yet supported in streaming mode."
-                )
-            else:
-                stream_iter = self._astream(
-                    messages, stop=stop, run_manager=run_manager, **kwargs
-                )
+            stream_iter = self._astream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
```
```diff
@@ -1117,6 +1063,7 @@ def _make_message_chunk_from_anthropic_event(
     event: anthropic.types.RawMessageStreamEvent,
     *,
     stream_usage: bool = True,
+    coerce_content_to_string: bool,
 ) -> Optional[AIMessageChunk]:
     """Convert Anthropic event to AIMessageChunk.
 
```
```diff
@@ -1124,20 +1071,60 @@ def _make_message_chunk_from_anthropic_event(
     we return None.
     """
     message_chunk: Optional[AIMessageChunk] = None
+    # See https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/lib/streaming/_messages.py  # noqa: E501
     if event.type == "message_start" and stream_usage:
         input_tokens = event.message.usage.input_tokens
         message_chunk = AIMessageChunk(
-            content="",
+            content="" if coerce_content_to_string else [],
             usage_metadata=UsageMetadata(
                 input_tokens=input_tokens,
                 output_tokens=0,
                 total_tokens=input_tokens,
             ),
         )
-    # See https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/lib/streaming/_messages.py  # noqa: E501
-    elif event.type == "content_block_delta" and event.delta.type == "text_delta":
-        text = event.delta.text
-        message_chunk = AIMessageChunk(content=text)
+    elif (
+        event.type == "content_block_start"
+        and event.content_block is not None
+        and event.content_block.type == "tool_use"
+    ):
+        if coerce_content_to_string:
+            warnings.warn("Received unexpected tool content block.")
+        content_block = event.content_block.model_dump()
+        content_block["index"] = event.index
+        tool_call_chunk = {
+            "index": event.index,
+            "id": event.content_block.id,
+            "name": event.content_block.name,
+            "args": "",
+        }
+        message_chunk = AIMessageChunk(
+            content=[content_block],
+            tool_call_chunks=[tool_call_chunk],  # type: ignore
+        )
+    elif event.type == "content_block_delta":
+        if event.delta.type == "text_delta":
+            if coerce_content_to_string:
+                text = event.delta.text
+                message_chunk = AIMessageChunk(content=text)
+            else:
+                content_block = event.delta.model_dump()
+                content_block["index"] = event.index
+                content_block["type"] = "text"
+                message_chunk = AIMessageChunk(content=[content_block])
+        elif event.delta.type == "input_json_delta":
+            content_block = event.delta.model_dump()
+            content_block["index"] = event.index
+            content_block["type"] = "tool_use"
+            tool_call_chunk = {
+                "index": event.index,
+                "id": None,
+                "name": None,
+                "args": event.delta.partial_json,
+            }
+            message_chunk = AIMessageChunk(
+                content=[content_block],
+                tool_call_chunks=[tool_call_chunk],  # type: ignore
+            )
     elif event.type == "message_delta" and stream_usage:
         output_tokens = event.usage.output_tokens
         message_chunk = AIMessageChunk(
```
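On the `merge_list`/`index` question raised in the commit message, here is a simplified sketch of index-keyed merging. This is an assumption about the spirit of the merge, not the actual langchain_core implementation; it shows why string fields concatenate across chunks, and hence why keeping a per-delta `type` field on every chunk would corrupt the merged block:

```python
from typing import Any, Dict, List


def merge_blocks_by_index(
    left: List[Dict[str, Any]], right: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Merge content-block lists, pairing blocks that share an "index".

    Hypothetical helper for illustration only; not langchain_core's merge_lists.
    """
    merged = [dict(block) for block in left]
    for block in right:
        match = next(
            (b for b in merged if b.get("index") == block.get("index")), None
        )
        if match is None:
            merged.append(dict(block))
            continue
        for key, value in block.items():
            if key == "index":
                continue  # the pairing key is kept, not merged
            if isinstance(value, str) and isinstance(match.get(key), str):
                match[key] = match[key] + value  # string deltas concatenate
            else:
                match[key] = value
    return merged


left = [{"index": 1, "type": "tool_use", "partial_json": '{"loc'}]
right = [{"index": 1, "partial_json": 'ation": "sf"}'}]  # delta omits "type"
assert merge_blocks_by_index(left, right) == [
    {"index": 1, "type": "tool_use", "partial_json": '{"location": "sf"}'}
]
```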
tests/integration_tests/test_chat_models.py:

```diff
@@ -146,6 +146,55 @@ async def test_abatch_tags() -> None:
         assert isinstance(token.content, str)
 
 
+async def test_async_tool_use() -> None:
+    llm = ChatAnthropic(  # type: ignore[call-arg]
+        model=MODEL_NAME,
+    )
+
+    llm_with_tools = llm.bind_tools(
+        [
+            {
+                "name": "get_weather",
+                "description": "Get weather report for a city",
+                "input_schema": {
+                    "type": "object",
+                    "properties": {"location": {"type": "string"}},
+                },
+            }
+        ]
+    )
+    response = await llm_with_tools.ainvoke("what's the weather in san francisco, ca")
+    assert isinstance(response, AIMessage)
+    assert isinstance(response.content, list)
+    assert isinstance(response.tool_calls, list)
+    assert len(response.tool_calls) == 1
+    tool_call = response.tool_calls[0]
+    assert tool_call["name"] == "get_weather"
+    assert isinstance(tool_call["args"], dict)
+    assert "location" in tool_call["args"]
+
+    # Test streaming
+    first = True
+    chunks = []  # type: ignore
+    async for chunk in llm_with_tools.astream(
+        "what's the weather in san francisco, ca"
+    ):
+        chunks = chunks + [chunk]
+        if first:
+            gathered = chunk
+            first = False
+        else:
+            gathered = gathered + chunk  # type: ignore
+    assert len(chunks) > 1
+    assert isinstance(gathered, AIMessageChunk)
+    assert isinstance(gathered.tool_call_chunks, list)
+    assert len(gathered.tool_call_chunks) == 1
+    tool_call_chunk = gathered.tool_call_chunks[0]
+    assert tool_call_chunk["name"] == "get_weather"
+    assert isinstance(tool_call_chunk["args"], str)
+    assert "location" in json.loads(tool_call_chunk["args"])
+
+
 def test_batch() -> None:
     """Test batch tokens from ChatAnthropicMessages."""
     llm = ChatAnthropicMessages(model_name=MODEL_NAME)  # type: ignore[call-arg, call-arg]
```
```diff
@@ -313,7 +362,7 @@ async def test_astreaming() -> None:
 
 def test_tool_use() -> None:
     llm = ChatAnthropic(  # type: ignore[call-arg]
-        model="claude-3-opus-20240229",
+        model=MODEL_NAME,
     )
 
     llm_with_tools = llm.bind_tools(
```
```diff
@@ -339,20 +388,55 @@ def test_tool_use() -> None:
     assert "location" in tool_call["args"]
 
     # Test streaming
+    input = "how are you? what's the weather in san francisco, ca"
     first = True
-    for chunk in llm_with_tools.stream("what's the weather in san francisco, ca"):
+    chunks = []  # type: ignore
+    for chunk in llm_with_tools.stream(input):
+        chunks = chunks + [chunk]
         if first:
             gathered = chunk
             first = False
         else:
             gathered = gathered + chunk  # type: ignore
+    assert len(chunks) > 1
+    assert isinstance(gathered.content, list)
+    assert len(gathered.content) == 2
+    tool_use_block = None
+    for content_block in gathered.content:
+        assert isinstance(content_block, dict)
+        if content_block["type"] == "tool_use":
+            tool_use_block = content_block
+            break
+    assert tool_use_block is not None
+    assert tool_use_block["name"] == "get_weather"
+    assert "location" in json.loads(tool_use_block["partial_json"])
     assert isinstance(gathered, AIMessageChunk)
-    assert isinstance(gathered.tool_call_chunks, list)
-    assert len(gathered.tool_call_chunks) == 1
-    tool_call_chunk = gathered.tool_call_chunks[0]
-    assert tool_call_chunk["name"] == "get_weather"
-    assert isinstance(tool_call_chunk["args"], str)
-    assert "location" in json.loads(tool_call_chunk["args"])
+    assert isinstance(gathered.tool_calls, list)
+    assert len(gathered.tool_calls) == 1
+    tool_call = gathered.tool_calls[0]
+    assert tool_call["name"] == "get_weather"
+    assert isinstance(tool_call["args"], dict)
+    assert "location" in tool_call["args"]
+    assert tool_call["id"] is not None
+
+    # Test passing response back to model
+    stream = llm_with_tools.stream(
+        [
+            input,
+            gathered,
+            ToolMessage(content="sunny and warm", tool_call_id=tool_call["id"]),
+        ]
+    )
+    chunks = []  # type: ignore
+    first = True
+    for chunk in stream:
+        chunks = chunks + [chunk]
+        if first:
+            gathered = chunk
+            first = False
+        else:
+            gathered = gathered + chunk  # type: ignore
+    assert len(chunks) > 1
 
 
 def test_anthropic_with_empty_text_block() -> None:
```
```diff
@@ -428,7 +512,7 @@ class GetWeather(BaseModel):
 @pytest.mark.parametrize("tool_choice", ["GetWeather", "auto", "any"])
 def test_anthropic_bind_tools_tool_choice(tool_choice: str) -> None:
     chat_model = ChatAnthropic(  # type: ignore[call-arg]
-        model="claude-3-sonnet-20240229",
+        model=MODEL_NAME,
     )
     chat_model_with_tools = chat_model.bind_tools([GetWeather], tool_choice=tool_choice)
     response = chat_model_with_tools.invoke("what's the weather in ny and la")
```