diff --git a/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py b/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py index 3a4bb9a69a1..8fbde19764d 100644 --- a/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py +++ b/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py @@ -1,3 +1,4 @@ +import json import logging import uuid from operator import itemgetter @@ -65,7 +66,7 @@ def convert_message_to_dict(message: BaseMessage) -> dict: elif isinstance(message, (FunctionMessage, ToolMessage)): message_dict = { "role": "function", - "content": message.content, + "content": _create_tool_content(message.content), "name": message.name or message.additional_kwargs.get("name"), } else: @@ -74,6 +75,20 @@ def convert_message_to_dict(message: BaseMessage) -> dict: return message_dict +def _create_tool_content(content: Union[str, List[Union[str, Dict[Any, Any]]]]) -> str: + """Convert tool content to dict scheme.""" + if isinstance(content, str): + try: + if isinstance(json.loads(content), dict): + return content + else: + return json.dumps({"tool_result": content}) + except json.JSONDecodeError: + return json.dumps({"tool_result": content}) + else: + return json.dumps({"tool_result": content}) + + def _convert_dict_to_message(_dict: Mapping[str, Any]) -> AIMessage: content = _dict.get("result", "") or "" additional_kwargs: Mapping[str, Any] = {} diff --git a/libs/community/tests/integration_tests/chat_models/test_baiduqianfan.py b/libs/community/tests/integration_tests/chat_models/test_baiduqianfan.py index e69de29bb2d..b2d169f9c71 100644 --- a/libs/community/tests/integration_tests/chat_models/test_baiduqianfan.py +++ b/libs/community/tests/integration_tests/chat_models/test_baiduqianfan.py @@ -0,0 +1,39 @@ +from langchain_core.messages import AIMessage, HumanMessage, ToolMessage +from langchain_core.messages.tool import ToolCall +from langchain_core.tools import tool + +from langchain_community.chat_models import QianfanChatEndpoint + + +@tool +def get_current_weather(location: str, unit: str = "摄氏度") -> str: + """获取指定地点的天气""" + return f"{location}是晴朗,25{unit}左右。" + + +def test_chat_qianfan_tool_result_to_model() -> None: + """Test QianfanChatEndpoint invoke with tool_calling result.""" + messages = [ + HumanMessage("上海天气怎么样?"), + AIMessage( + content=" ", + tool_calls=[ + ToolCall( + name="get_current_weather", + args={"location": "上海", "unit": "摄氏度"}, + id="foo", + type="tool_call", + ), + ], + ), + ToolMessage( + content="上海是晴天,25度左右。", + tool_call_id="foo", + name="get_current_weather", + ), + ] + chat = QianfanChatEndpoint(model="ERNIE-3.5-8K") # type: ignore[call-arg] + llm_with_tool = chat.bind_tools([get_current_weather]) + response = llm_with_tool.invoke(messages) + assert isinstance(response, AIMessage) + print(response.content) # noqa: T201