diff --git a/libs/community/langchain_community/chat_models/zhipuai.py b/libs/community/langchain_community/chat_models/zhipuai.py index 03d85839876..349a8b59ca4 100644 --- a/libs/community/langchain_community/chat_models/zhipuai.py +++ b/libs/community/langchain_community/chat_models/zhipuai.py @@ -42,6 +42,7 @@ from langchain_core.messages import ( HumanMessageChunk, SystemMessage, SystemMessageChunk, + ToolMessage, ) from langchain_core.output_parsers.base import OutputParserLike from langchain_core.output_parsers.openai_tools import ( @@ -150,6 +151,15 @@ def _convert_dict_to_message(dct: Dict[str, Any]) -> BaseMessage: if tool_calls is not None: additional_kwargs["tool_calls"] = tool_calls return AIMessage(content=content, additional_kwargs=additional_kwargs) + if role == "tool": + additional_kwargs = {} + if "name" in dct: + additional_kwargs["name"] = dct["name"] + return ToolMessage( + content=content, + tool_call_id=dct.get("tool_call_id"), # type: ignore[arg-type] + additional_kwargs=additional_kwargs, + ) return ChatMessage(role=role, content=content) # type: ignore[arg-type] @@ -171,6 +181,13 @@ def _convert_message_to_dict(message: BaseMessage) -> Dict[str, Any]: message_dict = {"role": "user", "content": message.content} elif isinstance(message, AIMessage): message_dict = {"role": "assistant", "content": message.content} + elif isinstance(message, ToolMessage): + message_dict = { + "role": "tool", + "content": message.content, + "tool_call_id": message.tool_call_id, + "name": message.name or message.additional_kwargs.get("name"), + } else: raise TypeError(f"Got unknown type '{message.__class__.__name__}'.") return message_dict diff --git a/libs/community/tests/unit_tests/chat_models/test_zhipuai.py b/libs/community/tests/unit_tests/chat_models/test_zhipuai.py index 5295b6f340b..41d3e468f51 100644 --- a/libs/community/tests/unit_tests/chat_models/test_zhipuai.py +++ b/libs/community/tests/unit_tests/chat_models/test_zhipuai.py @@ -1,8 +1,12 @@ """Test ZhipuAI Chat API wrapper""" import pytest +from langchain_core.messages import ToolMessage -from langchain_community.chat_models.zhipuai import ChatZhipuAI +from langchain_community.chat_models.zhipuai import ( + ChatZhipuAI, + _convert_message_to_dict, +) @pytest.mark.requires("httpx", "httpx_sse", "jwt") @@ -11,3 +15,15 @@ def test_zhipuai_model_param() -> None: assert llm.model_name == "foo" llm = ChatZhipuAI(api_key="test", model_name="foo") # type: ignore[call-arg] assert llm.model_name == "foo" + + +def test__convert_message_to_dict_with_tool() -> None: + message = ToolMessage(name="foo", content="bar", tool_call_id="abc123") + result = _convert_message_to_dict(message) + expected_output = { + "name": "foo", + "content": "bar", + "tool_call_id": "abc123", + "role": "tool", + } + assert result == expected_output