community: adding tool_call_id for every ToolCall (#22323)

- **Description:** This PR contains a bugfix which result in malfunction
of multi-turn conversation in QianfanChatEndpoint and adaption for
ToolCall and ToolMessage
This commit is contained in:
Dobiichi-Origami 2024-05-30 22:59:08 +08:00 committed by GitHub
parent 569d325a59
commit 10b12e1c08
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,4 +1,5 @@
import logging
import uuid
from operator import itemgetter
from typing import (
Any,
@ -29,6 +30,7 @@ from langchain_core.messages import (
FunctionMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
@ -59,7 +61,7 @@ def convert_message_to_dict(message: BaseMessage) -> dict:
# If function call only, content is None not empty string
if message_dict["content"] == "":
message_dict["content"] = None
elif isinstance(message, FunctionMessage):
elif isinstance(message, (FunctionMessage, ToolMessage)):
message_dict = {
"role": "function",
"content": message.content,
@ -81,21 +83,28 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> AIMessage:
additional_kwargs["function_call"].pop("thoughts")
additional_kwargs = {**_dict.get("body", {}), **additional_kwargs}
msg_additional_kwargs = dict(
finish_reason=additional_kwargs.get("finish_reason", ""),
request_id=additional_kwargs["id"],
object=additional_kwargs.get("object", ""),
search_info=additional_kwargs.get("search_info", []),
)
if additional_kwargs.get("function_call", {}):
msg_additional_kwargs["function_call"] = additional_kwargs.get(
"function_call", {}
)
msg_additional_kwargs["tool_calls"] = [
{
"type": "function",
"function": additional_kwargs.get("function_call", {}),
"id": str(uuid.uuid4()),
}
]
return AIMessage(
content=content,
additional_kwargs=dict(
finish_reason=additional_kwargs.get("finish_reason", ""),
request_id=additional_kwargs["id"],
object=additional_kwargs.get("object", ""),
search_info=additional_kwargs.get("search_info", []),
function_call=additional_kwargs.get("function_call", {}),
tool_calls=[
{
"type": "function",
"function": additional_kwargs.get("function_call", {}),
}
],
),
additional_kwargs=msg_additional_kwargs,
)