From cfd554bffce6a2398bd54100f5dc85bf604402b8 Mon Sep 17 00:00:00 2001 From: isaac hershenson Date: Fri, 1 Nov 2024 16:43:57 -0700 Subject: [PATCH] wip --- libs/core/langchain_core/messages/utils.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index 0ea03e40d00..14120caec59 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -317,9 +317,14 @@ def _convert_to_message(message: MessageLikeRepresentation) -> BaseMessage: except KeyError: msg_type = msg_kwargs.pop("type") # None msg content is not allowed - msg_content = msg_kwargs.pop("content") or "" + content_or_tool_calls = ( + "tool_calls" in msg_kwargs or "content" in msg_kwargs + ) + if not content_or_tool_calls: + raise KeyError("Must have one of content or tool calls") + msg_content = msg_kwargs.pop("content", "") or "" except KeyError as e: - msg = f"Message dict must contain 'role' and 'content' keys, got {message}" + msg = f"Message dict must contain 'role' and one of 'content' or 'tool_calls' keys, got {message}" msg = create_message( message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE ) @@ -957,6 +962,10 @@ def convert_to_openai_messages( oai_msg["name"] = message.name if isinstance(message, AIMessage) and message.tool_calls: oai_msg["tool_calls"] = _convert_to_openai_tool_calls(message.tool_calls) + if isinstance(message, AIMessage) and message.invalid_tool_calls: + oai_msg["tool_calls"] = oai_msg.get( + "tool_calls", [] + ) + _convert_to_openai_tool_calls(message.invalid_tool_calls) if message.additional_kwargs.get("refusal"): oai_msg["refusal"] = message.additional_kwargs["refusal"] if isinstance(message, ToolMessage): @@ -1393,14 +1402,18 @@ def _get_message_openai_role(message: BaseMessage) -> str: raise ValueError(msg) -def _convert_to_openai_tool_calls(tool_calls: list[ToolCall]) -> list[dict]: +def _convert_to_openai_tool_calls( + tool_calls: list[ToolCall], invalid=False +) -> list[dict]: return [ { "type": "function", "id": tool_call["id"], "function": { "name": tool_call["name"], - "arguments": json.dumps(tool_call["args"]), + "arguments": tool_call["args"] + if invalid + else json.dumps(tool_call["args"]), }, } for tool_call in tool_calls