diff --git a/libs/community/tests/integration_tests/chat_models/test_zhipuai.py b/libs/community/tests/integration_tests/chat_models/test_zhipuai.py index 5750705eab9..1f94712a259 100644 --- a/libs/community/tests/integration_tests/chat_models/test_zhipuai.py +++ b/libs/community/tests/integration_tests/chat_models/test_zhipuai.py @@ -1,8 +1,9 @@ """Test ZhipuAI Chat Model.""" from langchain_core.callbacks import CallbackManager -from langchain_core.messages import AIMessage, BaseMessage, HumanMessage +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage from langchain_core.outputs import ChatGeneration, LLMResult +from langchain_core.tools import tool from langchain_community.chat_models.zhipuai import ChatZhipuAI from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler @@ -71,3 +72,38 @@ def test_multiple_messages() -> None: assert isinstance(generation, ChatGeneration) assert isinstance(generation.text, str) assert generation.text == generation.message.content + + +@tool +def add(a: int, b: int) -> int: + """Adds a and b.""" + return a + b + + +@tool +def multiply(a: int, b: int) -> int: + """Multiplies a and b.""" + return a * b + + +def test_tool_call() -> None: + """Test tool calling by ChatZhipuAI""" + chat = ChatZhipuAI(model="glm-4-long") # type: ignore[call-arg] + tools = [add, multiply] + chat_with_tools = chat.bind_tools(tools) + + query = "What is 3 * 12?" + messages = [HumanMessage(query)] + ai_msg = chat_with_tools.invoke(messages) + assert isinstance(ai_msg, AIMessage) + assert isinstance(ai_msg.tool_calls, list) + assert len(ai_msg.tool_calls) == 1 + tool_call = ai_msg.tool_calls[0] + assert "args" in tool_call + messages.append(ai_msg) # type: ignore[arg-type] + for tool_call in ai_msg.tool_calls: + selected_tool = {"add": add, "multiply": multiply}[tool_call["name"].lower()] + tool_output = selected_tool.invoke(tool_call["args"]) # type: ignore[attr-defined] + messages.append(ToolMessage(tool_output, tool_call_id=tool_call["id"])) # type: ignore[arg-type] + response = chat_with_tools.invoke(messages) + assert isinstance(response, AIMessage)