diff --git a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py index 7929dfcba53..e1a99772d6a 100644 --- a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py +++ b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py @@ -1,9 +1,10 @@ +import json from abc import ABC, abstractmethod from typing import Type import pytest from langchain_core.language_models import BaseChatModel -from langchain_core.messages import AIMessage, AIMessageChunk +from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage, ToolMessage from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.tools import tool @@ -115,3 +116,38 @@ class ChatModelIntegrationTests(ABC): assert isinstance(result, AIMessage) assert isinstance(result.content, str) assert len(result.content) > 0 + + def test_tool_message( + self, + chat_model_class: Type[BaseChatModel], + chat_model_params: dict, + chat_model_has_tool_calling: bool, + ) -> None: + if not chat_model_has_tool_calling: + pytest.skip("Test requires tool calling.") + model = chat_model_class(**chat_model_params) + model_with_tools = model.bind_tools([my_adder_tool]) + function_name = "my_adder_tool" + function_args = {"a": "1", "b": "2"} + + messages = [ + HumanMessage(content="What is 1 + 2"), + AIMessage( + content="", + tool_calls=[ + { + "name": function_name, + "args": function_args, + "id": "abc123", + }, + ], + ), + ToolMessage( + name=function_name, + content=json.dumps({"result": 3}), + tool_call_id="abc123", + ), + ] + + result = model_with_tools.invoke(messages) + assert isinstance(result, AIMessage)