diff --git a/libs/core/langchain_core/messages/ai.py b/libs/core/langchain_core/messages/ai.py index 693d4545db8..2c50f71b86a 100644 --- a/libs/core/langchain_core/messages/ai.py +++ b/libs/core/langchain_core/messages/ai.py @@ -241,7 +241,7 @@ class AIMessageChunk(AIMessage, BaseMessageChunk): invalid_tool_calls = [] for chunk in values["tool_call_chunks"]: try: - args_ = parse_partial_json(chunk["args"]) + args_ = parse_partial_json(chunk["args"]) if chunk["args"] != "" else {} if isinstance(args_, dict): tool_calls.append( ToolCall( diff --git a/libs/core/tests/unit_tests/test_messages.py b/libs/core/tests/unit_tests/test_messages.py index 9dfef59b5eb..277cd07f0fa 100644 --- a/libs/core/tests/unit_tests/test_messages.py +++ b/libs/core/tests/unit_tests/test_messages.py @@ -121,6 +121,12 @@ def test_message_chunks() -> None: assert ai_msg_chunk + tool_calls_msg_chunk == tool_calls_msg_chunk assert tool_calls_msg_chunk + ai_msg_chunk == tool_calls_msg_chunk + ai_msg_chunk = AIMessageChunk( + content="", + tool_call_chunks=[ToolCallChunk(name="tool1", args="", id="1", index=0)], + ) + assert ai_msg_chunk.tool_calls == [ToolCall(name="tool1", args={}, id="1")] + # Test token usage left = AIMessageChunk( content="", diff --git a/libs/partners/groq/tests/integration_tests/test_standard.py b/libs/partners/groq/tests/integration_tests/test_standard.py index 08d3ded5926..c43eb30cf76 100644 --- a/libs/partners/groq/tests/integration_tests/test_standard.py +++ b/libs/partners/groq/tests/integration_tests/test_standard.py @@ -34,6 +34,12 @@ class TestGroqMixtral(BaseTestGroq): def test_structured_output(self, model: BaseChatModel) -> None: super().test_structured_output(model) + @pytest.mark.xfail( + reason=("May pass arguments: {'properties': {}, 'type': 'object'}") + ) + def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None: + super().test_tool_calling_with_no_arguments(model) + class TestGroqLlama(BaseTestGroq): @property diff --git a/libs/partners/together/tests/integration_tests/test_chat_models_standard.py b/libs/partners/together/tests/integration_tests/test_chat_models_standard.py index a66f1be78b7..b91f9e784c3 100644 --- a/libs/partners/together/tests/integration_tests/test_chat_models_standard.py +++ b/libs/partners/together/tests/integration_tests/test_chat_models_standard.py @@ -2,6 +2,7 @@ from typing import Type +import pytest from langchain_core.language_models import BaseChatModel from langchain_standard_tests.integration_tests import ( # type: ignore[import-not-found] ChatModelIntegrationTests, # type: ignore[import-not-found] @@ -18,3 +19,7 @@ class TestTogetherStandard(ChatModelIntegrationTests): @property def chat_model_params(self) -> dict: return {"model": "mistralai/Mistral-7B-Instruct-v0.1"} + + @pytest.mark.xfail(reason=("May not call a tool.")) + def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None: + super().test_tool_calling_with_no_arguments(model) diff --git a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py index 53ba4ab161d..a0099844362 100644 --- a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py +++ b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py @@ -8,6 +8,7 @@ from langchain_core.language_models import BaseChatModel from langchain_core.messages import ( AIMessage, AIMessageChunk, + BaseMessage, BaseMessageChunk, HumanMessage, SystemMessage, @@ -28,7 +29,13 @@ def magic_function(input: int) -> int: return input + 2 -def _validate_tool_call_message(message: AIMessage) -> None: +@tool +def magic_function_no_args() -> int: + """Calculates a magic function.""" + return 5 + + +def _validate_tool_call_message(message: BaseMessage) -> None: assert isinstance(message, AIMessage) assert len(message.tool_calls) == 1 tool_call = message.tool_calls[0] @@ -37,6 +44,15 @@ def _validate_tool_call_message(message: AIMessage) -> None: assert tool_call["id"] is not None +def _validate_tool_call_message_no_args(message: BaseMessage) -> None: + assert isinstance(message, AIMessage) + assert len(message.tool_calls) == 1 + tool_call = message.tool_calls[0] + assert tool_call["name"] == "magic_function_no_args" + assert tool_call["args"] == {} + assert tool_call["id"] is not None + + class ChatModelIntegrationTests(ChatModelTests): def test_invoke(self, model: BaseChatModel) -> None: result = model.invoke("Hello") @@ -131,7 +147,6 @@ class ChatModelIntegrationTests(ChatModelTests): # Test invoke query = "What is the value of magic_function(3)? Use the tool." result = model_with_tools.invoke(query) - assert isinstance(result, AIMessage) _validate_tool_call_message(result) # Test stream @@ -141,6 +156,21 @@ class ChatModelIntegrationTests(ChatModelTests): assert isinstance(full, AIMessage) _validate_tool_call_message(full) + def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None: + if not self.has_tool_calling: + pytest.skip("Test requires tool calling.") + + model_with_tools = model.bind_tools([magic_function_no_args]) + query = "What is the value of magic_function()? Use the tool." + result = model_with_tools.invoke(query) + _validate_tool_call_message_no_args(result) + + full: Optional[BaseMessageChunk] = None + for chunk in model_with_tools.stream(query): + full = chunk if full is None else full + chunk # type: ignore + assert isinstance(full, AIMessage) + _validate_tool_call_message_no_args(full) + def test_structured_output(self, model: BaseChatModel) -> None: if not self.has_tool_calling: pytest.skip("Test requires tool calling.")