From 721f709dec4aea0bbc2db7a518783257fac5272d Mon Sep 17 00:00:00 2001 From: maang-h <55082429+maang-h@users.noreply.github.com> Date: Mon, 22 Jul 2024 23:29:00 +0800 Subject: [PATCH] community: Improve QianfanChatEndpoint tool result to model (#24466) - **Description:** `QianfanChatEndpoint` When using tool result to answer questions, the content of the tool is required to be in Dict format. Of course, this can require users to return Dict format when calling the tool, but in order to be consistent with other Chat Models, I think such modifications are necessary. --- .../chat_models/baidu_qianfan_endpoint.py | 17 +++++++- .../chat_models/test_baiduqianfan.py | 39 +++++++++++++++++++ 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py b/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py index 3a4bb9a69a1..8fbde19764d 100644 --- a/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py +++ b/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py @@ -1,3 +1,4 @@ +import json import logging import uuid from operator import itemgetter @@ -65,7 +66,7 @@ def convert_message_to_dict(message: BaseMessage) -> dict: elif isinstance(message, (FunctionMessage, ToolMessage)): message_dict = { "role": "function", - "content": message.content, + "content": _create_tool_content(message.content), "name": message.name or message.additional_kwargs.get("name"), } else: @@ -74,6 +75,20 @@ def convert_message_to_dict(message: BaseMessage) -> dict: return message_dict +def _create_tool_content(content: Union[str, List[Union[str, Dict[Any, Any]]]]) -> str: + """Convert tool content to dict scheme.""" + if isinstance(content, str): + try: + if isinstance(json.loads(content), dict): + return content + else: + return json.dumps({"tool_result": content}) + except json.JSONDecodeError: + return json.dumps({"tool_result": content}) + else: + return json.dumps({"tool_result": content}) + + def _convert_dict_to_message(_dict: Mapping[str, Any]) -> AIMessage: content = _dict.get("result", "") or "" additional_kwargs: Mapping[str, Any] = {} diff --git a/libs/community/tests/integration_tests/chat_models/test_baiduqianfan.py b/libs/community/tests/integration_tests/chat_models/test_baiduqianfan.py index e69de29bb2d..b2d169f9c71 100644 --- a/libs/community/tests/integration_tests/chat_models/test_baiduqianfan.py +++ b/libs/community/tests/integration_tests/chat_models/test_baiduqianfan.py @@ -0,0 +1,39 @@ +from langchain_core.messages import AIMessage, HumanMessage, ToolMessage +from langchain_core.messages.tool import ToolCall +from langchain_core.tools import tool + +from langchain_community.chat_models import QianfanChatEndpoint + + +@tool +def get_current_weather(location: str, unit: str = "摄氏度") -> str: + """获取指定地点的天气""" + return f"{location}是晴朗,25{unit}左右。" + + +def test_chat_qianfan_tool_result_to_model() -> None: + """Test QianfanChatEndpoint invoke with tool_calling result.""" + messages = [ + HumanMessage("上海天气怎么样?"), + AIMessage( + content=" ", + tool_calls=[ + ToolCall( + name="get_current_weather", + args={"location": "上海", "unit": "摄氏度"}, + id="foo", + type="tool_call", + ), + ], + ), + ToolMessage( + content="上海是晴天,25度左右。", + tool_call_id="foo", + name="get_current_weather", + ), + ] + chat = QianfanChatEndpoint(model="ERNIE-3.5-8K") # type: ignore[call-arg] + llm_with_tool = chat.bind_tools([get_current_weather]) + response = llm_with_tool.invoke(messages) + assert isinstance(response, AIMessage) + print(response.content) # noqa: T201