This commit is contained in:
isaac hershenson 2024-11-01 16:43:57 -07:00
parent 830cad7bc0
commit cfd554bffc

View File

@ -317,9 +317,14 @@ def _convert_to_message(message: MessageLikeRepresentation) -> BaseMessage:
except KeyError:
msg_type = msg_kwargs.pop("type")
# None msg content is not allowed
msg_content = msg_kwargs.pop("content") or ""
content_or_tool_calls = (
"tool_calls" in msg_kwargs or "content" in msg_kwargs
)
if not content_or_tool_calls:
raise KeyError("Must have one of content or tool calls")
msg_content = msg_kwargs.pop("content", "") or ""
except KeyError as e:
msg = f"Message dict must contain 'role' and 'content' keys, got {message}"
msg = f"Message dict must contain 'role' and one of 'content' or 'tool_calls' keys, got {message}"
msg = create_message(
message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE
)
@ -957,6 +962,10 @@ def convert_to_openai_messages(
oai_msg["name"] = message.name
if isinstance(message, AIMessage) and message.tool_calls:
oai_msg["tool_calls"] = _convert_to_openai_tool_calls(message.tool_calls)
if isinstance(message, AIMessage) and message.invalid_tool_calls:
oai_msg["tool_calls"] = oai_msg.get(
"tool_calls", []
) + _convert_to_openai_tool_calls(message.invalid_tool_calls)
if message.additional_kwargs.get("refusal"):
oai_msg["refusal"] = message.additional_kwargs["refusal"]
if isinstance(message, ToolMessage):
@ -1393,14 +1402,18 @@ def _get_message_openai_role(message: BaseMessage) -> str:
raise ValueError(msg)
def _convert_to_openai_tool_calls(tool_calls: list[ToolCall]) -> list[dict]:
def _convert_to_openai_tool_calls(
tool_calls: list[ToolCall], invalid=False
) -> list[dict]:
return [
{
"type": "function",
"id": tool_call["id"],
"function": {
"name": tool_call["name"],
"arguments": json.dumps(tool_call["args"]),
"arguments": tool_call["args"]
if invalid
else json.dumps(tool_call["args"]),
},
}
for tool_call in tool_calls