mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 23:29:21 +00:00
openai[patch]: Assign message id in ChatOpenAI (#17837)
This commit is contained in:
parent
733367b795
commit
a99eb3abf4
@ -92,8 +92,9 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
|||||||
The LangChain message.
|
The LangChain message.
|
||||||
"""
|
"""
|
||||||
role = _dict.get("role")
|
role = _dict.get("role")
|
||||||
|
id_ = _dict.get("id")
|
||||||
if role == "user":
|
if role == "user":
|
||||||
return HumanMessage(content=_dict.get("content", ""))
|
return HumanMessage(content=_dict.get("content", ""), id=id_)
|
||||||
elif role == "assistant":
|
elif role == "assistant":
|
||||||
# Fix for azure
|
# Fix for azure
|
||||||
# Also OpenAI returns None for tool invocations
|
# Also OpenAI returns None for tool invocations
|
||||||
@ -103,11 +104,13 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
|||||||
additional_kwargs["function_call"] = dict(function_call)
|
additional_kwargs["function_call"] = dict(function_call)
|
||||||
if tool_calls := _dict.get("tool_calls"):
|
if tool_calls := _dict.get("tool_calls"):
|
||||||
additional_kwargs["tool_calls"] = tool_calls
|
additional_kwargs["tool_calls"] = tool_calls
|
||||||
return AIMessage(content=content, additional_kwargs=additional_kwargs)
|
return AIMessage(content=content, additional_kwargs=additional_kwargs, id=id_)
|
||||||
elif role == "system":
|
elif role == "system":
|
||||||
return SystemMessage(content=_dict.get("content", ""))
|
return SystemMessage(content=_dict.get("content", ""), id=id_)
|
||||||
elif role == "function":
|
elif role == "function":
|
||||||
return FunctionMessage(content=_dict.get("content", ""), name=_dict.get("name"))
|
return FunctionMessage(
|
||||||
|
content=_dict.get("content", ""), name=_dict.get("name"), id=id_
|
||||||
|
)
|
||||||
elif role == "tool":
|
elif role == "tool":
|
||||||
additional_kwargs = {}
|
additional_kwargs = {}
|
||||||
if "name" in _dict:
|
if "name" in _dict:
|
||||||
@ -116,9 +119,10 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
|||||||
content=_dict.get("content", ""),
|
content=_dict.get("content", ""),
|
||||||
tool_call_id=_dict.get("tool_call_id"),
|
tool_call_id=_dict.get("tool_call_id"),
|
||||||
additional_kwargs=additional_kwargs,
|
additional_kwargs=additional_kwargs,
|
||||||
|
id=id_,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return ChatMessage(content=_dict.get("content", ""), role=role)
|
return ChatMessage(content=_dict.get("content", ""), role=role, id=id_)
|
||||||
|
|
||||||
|
|
||||||
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||||
@ -171,6 +175,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
|||||||
def _convert_delta_to_message_chunk(
|
def _convert_delta_to_message_chunk(
|
||||||
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
|
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
|
||||||
) -> BaseMessageChunk:
|
) -> BaseMessageChunk:
|
||||||
|
id_ = _dict.get("id")
|
||||||
role = cast(str, _dict.get("role"))
|
role = cast(str, _dict.get("role"))
|
||||||
content = cast(str, _dict.get("content") or "")
|
content = cast(str, _dict.get("content") or "")
|
||||||
additional_kwargs: Dict = {}
|
additional_kwargs: Dict = {}
|
||||||
@ -183,19 +188,23 @@ def _convert_delta_to_message_chunk(
|
|||||||
additional_kwargs["tool_calls"] = _dict["tool_calls"]
|
additional_kwargs["tool_calls"] = _dict["tool_calls"]
|
||||||
|
|
||||||
if role == "user" or default_class == HumanMessageChunk:
|
if role == "user" or default_class == HumanMessageChunk:
|
||||||
return HumanMessageChunk(content=content)
|
return HumanMessageChunk(content=content, id=id_)
|
||||||
elif role == "assistant" or default_class == AIMessageChunk:
|
elif role == "assistant" or default_class == AIMessageChunk:
|
||||||
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
|
return AIMessageChunk(
|
||||||
|
content=content, additional_kwargs=additional_kwargs, id=id_
|
||||||
|
)
|
||||||
elif role == "system" or default_class == SystemMessageChunk:
|
elif role == "system" or default_class == SystemMessageChunk:
|
||||||
return SystemMessageChunk(content=content)
|
return SystemMessageChunk(content=content, id=id_)
|
||||||
elif role == "function" or default_class == FunctionMessageChunk:
|
elif role == "function" or default_class == FunctionMessageChunk:
|
||||||
return FunctionMessageChunk(content=content, name=_dict["name"])
|
return FunctionMessageChunk(content=content, name=_dict["name"], id=id_)
|
||||||
elif role == "tool" or default_class == ToolMessageChunk:
|
elif role == "tool" or default_class == ToolMessageChunk:
|
||||||
return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
|
return ToolMessageChunk(
|
||||||
|
content=content, tool_call_id=_dict["tool_call_id"], id=id_
|
||||||
|
)
|
||||||
elif role or default_class == ChatMessageChunk:
|
elif role or default_class == ChatMessageChunk:
|
||||||
return ChatMessageChunk(content=content, role=role)
|
return ChatMessageChunk(content=content, role=role, id=id_)
|
||||||
else:
|
else:
|
||||||
return default_class(content=content) # type: ignore
|
return default_class(content=content, id=id_) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class _FunctionCall(TypedDict):
|
class _FunctionCall(TypedDict):
|
||||||
|
Loading…
Reference in New Issue
Block a user