openai[patch]: Assign message id in ChatOpenAI (#17837)

This commit is contained in:
Nuno Campos 2024-02-27 17:32:54 -08:00 committed by GitHub
parent 733367b795
commit a99eb3abf4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -92,8 +92,9 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
The LangChain message. The LangChain message.
""" """
role = _dict.get("role") role = _dict.get("role")
id_ = _dict.get("id")
if role == "user": if role == "user":
return HumanMessage(content=_dict.get("content", "")) return HumanMessage(content=_dict.get("content", ""), id=id_)
elif role == "assistant": elif role == "assistant":
# Fix for azure # Fix for azure
# Also OpenAI returns None for tool invocations # Also OpenAI returns None for tool invocations
@ -103,11 +104,13 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
additional_kwargs["function_call"] = dict(function_call) additional_kwargs["function_call"] = dict(function_call)
if tool_calls := _dict.get("tool_calls"): if tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = tool_calls additional_kwargs["tool_calls"] = tool_calls
return AIMessage(content=content, additional_kwargs=additional_kwargs) return AIMessage(content=content, additional_kwargs=additional_kwargs, id=id_)
elif role == "system": elif role == "system":
return SystemMessage(content=_dict.get("content", "")) return SystemMessage(content=_dict.get("content", ""), id=id_)
elif role == "function": elif role == "function":
return FunctionMessage(content=_dict.get("content", ""), name=_dict.get("name")) return FunctionMessage(
content=_dict.get("content", ""), name=_dict.get("name"), id=id_
)
elif role == "tool": elif role == "tool":
additional_kwargs = {} additional_kwargs = {}
if "name" in _dict: if "name" in _dict:
@ -116,9 +119,10 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
content=_dict.get("content", ""), content=_dict.get("content", ""),
tool_call_id=_dict.get("tool_call_id"), tool_call_id=_dict.get("tool_call_id"),
additional_kwargs=additional_kwargs, additional_kwargs=additional_kwargs,
id=id_,
) )
else: else:
return ChatMessage(content=_dict.get("content", ""), role=role) return ChatMessage(content=_dict.get("content", ""), role=role, id=id_)
def _convert_message_to_dict(message: BaseMessage) -> dict: def _convert_message_to_dict(message: BaseMessage) -> dict:
@ -171,6 +175,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
def _convert_delta_to_message_chunk( def _convert_delta_to_message_chunk(
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk] _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk: ) -> BaseMessageChunk:
id_ = _dict.get("id")
role = cast(str, _dict.get("role")) role = cast(str, _dict.get("role"))
content = cast(str, _dict.get("content") or "") content = cast(str, _dict.get("content") or "")
additional_kwargs: Dict = {} additional_kwargs: Dict = {}
@ -183,19 +188,23 @@ def _convert_delta_to_message_chunk(
additional_kwargs["tool_calls"] = _dict["tool_calls"] additional_kwargs["tool_calls"] = _dict["tool_calls"]
if role == "user" or default_class == HumanMessageChunk: if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content) return HumanMessageChunk(content=content, id=id_)
elif role == "assistant" or default_class == AIMessageChunk: elif role == "assistant" or default_class == AIMessageChunk:
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs) return AIMessageChunk(
content=content, additional_kwargs=additional_kwargs, id=id_
)
elif role == "system" or default_class == SystemMessageChunk: elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content) return SystemMessageChunk(content=content, id=id_)
elif role == "function" or default_class == FunctionMessageChunk: elif role == "function" or default_class == FunctionMessageChunk:
return FunctionMessageChunk(content=content, name=_dict["name"]) return FunctionMessageChunk(content=content, name=_dict["name"], id=id_)
elif role == "tool" or default_class == ToolMessageChunk: elif role == "tool" or default_class == ToolMessageChunk:
return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"]) return ToolMessageChunk(
content=content, tool_call_id=_dict["tool_call_id"], id=id_
)
elif role or default_class == ChatMessageChunk: elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role) return ChatMessageChunk(content=content, role=role, id=id_)
else: else:
return default_class(content=content) # type: ignore return default_class(content=content, id=id_) # type: ignore
class _FunctionCall(TypedDict): class _FunctionCall(TypedDict):