From 10b12e1c08925b593285b55c76c48848c9a242dc Mon Sep 17 00:00:00 2001 From: Dobiichi-Origami <56953648+Dobiichi-Origami@users.noreply.github.com> Date: Thu, 30 May 2024 22:59:08 +0800 Subject: [PATCH] community: adding tool_call_id for every ToolCall (#22323) - **Description:** This PR contains a bugfix which result in malfunction of multi-turn conversation in QianfanChatEndpoint and adaption for ToolCall and ToolMessage --- .../chat_models/baidu_qianfan_endpoint.py | 37 ++++++++++++------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py b/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py index 95b5fc16390..0305c816f14 100644 --- a/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py +++ b/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py @@ -1,4 +1,5 @@ import logging +import uuid from operator import itemgetter from typing import ( Any, @@ -29,6 +30,7 @@ from langchain_core.messages import ( FunctionMessage, HumanMessage, SystemMessage, + ToolMessage, ) from langchain_core.output_parsers.base import OutputParserLike from langchain_core.output_parsers.openai_tools import ( @@ -59,7 +61,7 @@ def convert_message_to_dict(message: BaseMessage) -> dict: # If function call only, content is None not empty string if message_dict["content"] == "": message_dict["content"] = None - elif isinstance(message, FunctionMessage): + elif isinstance(message, (FunctionMessage, ToolMessage)): message_dict = { "role": "function", "content": message.content, @@ -81,21 +83,28 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> AIMessage: additional_kwargs["function_call"].pop("thoughts") additional_kwargs = {**_dict.get("body", {}), **additional_kwargs} + msg_additional_kwargs = dict( + finish_reason=additional_kwargs.get("finish_reason", ""), + request_id=additional_kwargs["id"], + object=additional_kwargs.get("object", ""), + search_info=additional_kwargs.get("search_info", []), + ) + + if additional_kwargs.get("function_call", {}): + msg_additional_kwargs["function_call"] = additional_kwargs.get( + "function_call", {} + ) + msg_additional_kwargs["tool_calls"] = [ + { + "type": "function", + "function": additional_kwargs.get("function_call", {}), + "id": str(uuid.uuid4()), + } + ] + return AIMessage( content=content, - additional_kwargs=dict( - finish_reason=additional_kwargs.get("finish_reason", ""), - request_id=additional_kwargs["id"], - object=additional_kwargs.get("object", ""), - search_info=additional_kwargs.get("search_info", []), - function_call=additional_kwargs.get("function_call", {}), - tool_calls=[ - { - "type": "function", - "function": additional_kwargs.get("function_call", {}), - } - ], - ), + additional_kwargs=msg_additional_kwargs, )