core, anthropic[patch]: support streaming tool calls when function has no arguments (#23915)

resolves https://github.com/langchain-ai/langchain/issues/23911

When an AIMessageChunk is instantiated, we attempt to parse tool calls
off of the tool_call_chunks.

Here we add a special-case to this parsing, where `""` will be parsed as
`{}`.

This is a reaction to how Anthropic streams tool calls in the case where
a function has no arguments:
```
{'id': 'toolu_01J8CgKcuUVrMqfTQWPYh64r', 'input': {}, 'name': 'magic_function', 'type': 'tool_use', 'index': 1}
{'partial_json': '', 'type': 'tool_use', 'index': 1}
```
The `partial_json` does not accumulate to a valid json string-- most
other providers tend to emit `"{}"` in this case.
This commit is contained in:
ccurme 2024-07-05 14:57:41 -04:00 committed by GitHub
parent 902b57d107
commit 74c7198906
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 50 additions and 3 deletions

View File

@ -241,7 +241,7 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
invalid_tool_calls = []
for chunk in values["tool_call_chunks"]:
try:
args_ = parse_partial_json(chunk["args"])
args_ = parse_partial_json(chunk["args"]) if chunk["args"] != "" else {}
if isinstance(args_, dict):
tool_calls.append(
ToolCall(

View File

@ -121,6 +121,12 @@ def test_message_chunks() -> None:
assert ai_msg_chunk + tool_calls_msg_chunk == tool_calls_msg_chunk
assert tool_calls_msg_chunk + ai_msg_chunk == tool_calls_msg_chunk
ai_msg_chunk = AIMessageChunk(
content="",
tool_call_chunks=[ToolCallChunk(name="tool1", args="", id="1", index=0)],
)
assert ai_msg_chunk.tool_calls == [ToolCall(name="tool1", args={}, id="1")]
# Test token usage
left = AIMessageChunk(
content="",

View File

@ -34,6 +34,12 @@ class TestGroqMixtral(BaseTestGroq):
def test_structured_output(self, model: BaseChatModel) -> None:
super().test_structured_output(model)
@pytest.mark.xfail(
reason=("May pass arguments: {'properties': {}, 'type': 'object'}")
)
def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None:
super().test_tool_calling_with_no_arguments(model)
class TestGroqLlama(BaseTestGroq):
@property

View File

@ -2,6 +2,7 @@
from typing import Type
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.integration_tests import ( # type: ignore[import-not-found]
ChatModelIntegrationTests, # type: ignore[import-not-found]
@ -18,3 +19,7 @@ class TestTogetherStandard(ChatModelIntegrationTests):
@property
def chat_model_params(self) -> dict:
return {"model": "mistralai/Mistral-7B-Instruct-v0.1"}
@pytest.mark.xfail(reason=("May not call a tool."))
def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None:
super().test_tool_calling_with_no_arguments(model)

View File

@ -8,6 +8,7 @@ from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
HumanMessage,
SystemMessage,
@ -28,7 +29,13 @@ def magic_function(input: int) -> int:
return input + 2
def _validate_tool_call_message(message: AIMessage) -> None:
@tool
def magic_function_no_args() -> int:
"""Calculates a magic function."""
return 5
def _validate_tool_call_message(message: BaseMessage) -> None:
assert isinstance(message, AIMessage)
assert len(message.tool_calls) == 1
tool_call = message.tool_calls[0]
@ -37,6 +44,15 @@ def _validate_tool_call_message(message: AIMessage) -> None:
assert tool_call["id"] is not None
def _validate_tool_call_message_no_args(message: BaseMessage) -> None:
assert isinstance(message, AIMessage)
assert len(message.tool_calls) == 1
tool_call = message.tool_calls[0]
assert tool_call["name"] == "magic_function_no_args"
assert tool_call["args"] == {}
assert tool_call["id"] is not None
class ChatModelIntegrationTests(ChatModelTests):
def test_invoke(self, model: BaseChatModel) -> None:
result = model.invoke("Hello")
@ -131,7 +147,6 @@ class ChatModelIntegrationTests(ChatModelTests):
# Test invoke
query = "What is the value of magic_function(3)? Use the tool."
result = model_with_tools.invoke(query)
assert isinstance(result, AIMessage)
_validate_tool_call_message(result)
# Test stream
@ -141,6 +156,21 @@ class ChatModelIntegrationTests(ChatModelTests):
assert isinstance(full, AIMessage)
_validate_tool_call_message(full)
def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None:
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")
model_with_tools = model.bind_tools([magic_function_no_args])
query = "What is the value of magic_function()? Use the tool."
result = model_with_tools.invoke(query)
_validate_tool_call_message_no_args(result)
full: Optional[BaseMessageChunk] = None
for chunk in model_with_tools.stream(query):
full = chunk if full is None else full + chunk # type: ignore
assert isinstance(full, AIMessage)
_validate_tool_call_message_no_args(full)
def test_structured_output(self, model: BaseChatModel) -> None:
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")