mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-01 17:13:22 +00:00
core, anthropic[patch]: support streaming tool calls when function has no arguments (#23915)
resolves https://github.com/langchain-ai/langchain/issues/23911 When an AIMessageChunk is instantiated, we attempt to parse tool calls off of the tool_call_chunks. Here we add a special-case to this parsing, where `""` will be parsed as `{}`. This is a reaction to how Anthropic streams tool calls in the case where a function has no arguments: ``` {'id': 'toolu_01J8CgKcuUVrMqfTQWPYh64r', 'input': {}, 'name': 'magic_function', 'type': 'tool_use', 'index': 1} {'partial_json': '', 'type': 'tool_use', 'index': 1} ``` The `partial_json` does not accumulate to a valid json string-- most other providers tend to emit `"{}"` in this case.
This commit is contained in:
parent
902b57d107
commit
74c7198906
@ -241,7 +241,7 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
|
||||
invalid_tool_calls = []
|
||||
for chunk in values["tool_call_chunks"]:
|
||||
try:
|
||||
args_ = parse_partial_json(chunk["args"])
|
||||
args_ = parse_partial_json(chunk["args"]) if chunk["args"] != "" else {}
|
||||
if isinstance(args_, dict):
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
|
@ -121,6 +121,12 @@ def test_message_chunks() -> None:
|
||||
assert ai_msg_chunk + tool_calls_msg_chunk == tool_calls_msg_chunk
|
||||
assert tool_calls_msg_chunk + ai_msg_chunk == tool_calls_msg_chunk
|
||||
|
||||
ai_msg_chunk = AIMessageChunk(
|
||||
content="",
|
||||
tool_call_chunks=[ToolCallChunk(name="tool1", args="", id="1", index=0)],
|
||||
)
|
||||
assert ai_msg_chunk.tool_calls == [ToolCall(name="tool1", args={}, id="1")]
|
||||
|
||||
# Test token usage
|
||||
left = AIMessageChunk(
|
||||
content="",
|
||||
|
@ -34,6 +34,12 @@ class TestGroqMixtral(BaseTestGroq):
|
||||
def test_structured_output(self, model: BaseChatModel) -> None:
|
||||
super().test_structured_output(model)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason=("May pass arguments: {'properties': {}, 'type': 'object'}")
|
||||
)
|
||||
def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None:
|
||||
super().test_tool_calling_with_no_arguments(model)
|
||||
|
||||
|
||||
class TestGroqLlama(BaseTestGroq):
|
||||
@property
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
from typing import Type
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_standard_tests.integration_tests import ( # type: ignore[import-not-found]
|
||||
ChatModelIntegrationTests, # type: ignore[import-not-found]
|
||||
@ -18,3 +19,7 @@ class TestTogetherStandard(ChatModelIntegrationTests):
|
||||
@property
|
||||
def chat_model_params(self) -> dict:
|
||||
return {"model": "mistralai/Mistral-7B-Instruct-v0.1"}
|
||||
|
||||
@pytest.mark.xfail(reason=("May not call a tool."))
|
||||
def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None:
|
||||
super().test_tool_calling_with_no_arguments(model)
|
||||
|
@ -8,6 +8,7 @@ from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
@ -28,7 +29,13 @@ def magic_function(input: int) -> int:
|
||||
return input + 2
|
||||
|
||||
|
||||
def _validate_tool_call_message(message: AIMessage) -> None:
|
||||
@tool
|
||||
def magic_function_no_args() -> int:
|
||||
"""Calculates a magic function."""
|
||||
return 5
|
||||
|
||||
|
||||
def _validate_tool_call_message(message: BaseMessage) -> None:
|
||||
assert isinstance(message, AIMessage)
|
||||
assert len(message.tool_calls) == 1
|
||||
tool_call = message.tool_calls[0]
|
||||
@ -37,6 +44,15 @@ def _validate_tool_call_message(message: AIMessage) -> None:
|
||||
assert tool_call["id"] is not None
|
||||
|
||||
|
||||
def _validate_tool_call_message_no_args(message: BaseMessage) -> None:
|
||||
assert isinstance(message, AIMessage)
|
||||
assert len(message.tool_calls) == 1
|
||||
tool_call = message.tool_calls[0]
|
||||
assert tool_call["name"] == "magic_function_no_args"
|
||||
assert tool_call["args"] == {}
|
||||
assert tool_call["id"] is not None
|
||||
|
||||
|
||||
class ChatModelIntegrationTests(ChatModelTests):
|
||||
def test_invoke(self, model: BaseChatModel) -> None:
|
||||
result = model.invoke("Hello")
|
||||
@ -131,7 +147,6 @@ class ChatModelIntegrationTests(ChatModelTests):
|
||||
# Test invoke
|
||||
query = "What is the value of magic_function(3)? Use the tool."
|
||||
result = model_with_tools.invoke(query)
|
||||
assert isinstance(result, AIMessage)
|
||||
_validate_tool_call_message(result)
|
||||
|
||||
# Test stream
|
||||
@ -141,6 +156,21 @@ class ChatModelIntegrationTests(ChatModelTests):
|
||||
assert isinstance(full, AIMessage)
|
||||
_validate_tool_call_message(full)
|
||||
|
||||
def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None:
|
||||
if not self.has_tool_calling:
|
||||
pytest.skip("Test requires tool calling.")
|
||||
|
||||
model_with_tools = model.bind_tools([magic_function_no_args])
|
||||
query = "What is the value of magic_function()? Use the tool."
|
||||
result = model_with_tools.invoke(query)
|
||||
_validate_tool_call_message_no_args(result)
|
||||
|
||||
full: Optional[BaseMessageChunk] = None
|
||||
for chunk in model_with_tools.stream(query):
|
||||
full = chunk if full is None else full + chunk # type: ignore
|
||||
assert isinstance(full, AIMessage)
|
||||
_validate_tool_call_message_no_args(full)
|
||||
|
||||
def test_structured_output(self, model: BaseChatModel) -> None:
|
||||
if not self.has_tool_calling:
|
||||
pytest.skip("Test requires tool calling.")
|
||||
|
Loading…
Reference in New Issue
Block a user