Compare commits

...

5 Commits

Author SHA1 Message Date
isaac hershenson
994bde53e3 fmt 2024-11-05 07:39:42 -08:00
isaac hershenson
df415417a1 fmt 2024-11-05 07:36:36 -08:00
isaac hershenson
85a1215217 fmt 2024-11-05 07:30:54 -08:00
isaac hershenson
1160090ce3 fix 2024-11-01 16:45:56 -07:00
isaac hershenson
cfd554bffc wip 2024-11-01 16:44:57 -07:00

View File

@@ -36,7 +36,12 @@ from langchain_core.messages.function import FunctionMessage, FunctionMessageChu
from langchain_core.messages.human import HumanMessage, HumanMessageChunk
from langchain_core.messages.modifier import RemoveMessage
from langchain_core.messages.system import SystemMessage, SystemMessageChunk
from langchain_core.messages.tool import ToolCall, ToolMessage, ToolMessageChunk
from langchain_core.messages.tool import (
InvalidToolCall,
ToolCall,
ToolMessage,
ToolMessageChunk,
)
if TYPE_CHECKING:
from langchain_text_splitters import TextSplitter
@@ -317,9 +322,15 @@ def _convert_to_message(message: MessageLikeRepresentation) -> BaseMessage:
except KeyError:
msg_type = msg_kwargs.pop("type")
# None msg content is not allowed
msg_content = msg_kwargs.pop("content") or ""
content_or_tool_calls = (
"tool_calls" in msg_kwargs or "content" in msg_kwargs
)
if not content_or_tool_calls:
msg = "Must have one of content or tool calls"
raise KeyError(msg)
msg_content = msg_kwargs.pop("content", "") or ""
except KeyError as e:
msg = f"Message dict must contain 'role' and 'content' keys, got {message}"
msg = f"Message dict must contain 'role' and one of 'content' or 'tool_calls' keys, got {message}" # noqa: E501
msg = create_message(
message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE
)
@@ -957,6 +968,10 @@ def convert_to_openai_messages(
oai_msg["name"] = message.name
if isinstance(message, AIMessage) and message.tool_calls:
oai_msg["tool_calls"] = _convert_to_openai_tool_calls(message.tool_calls)
if isinstance(message, AIMessage) and message.invalid_tool_calls:
oai_msg["tool_calls"] = oai_msg.get(
"tool_calls", []
) + _convert_to_openai_tool_calls(message.invalid_tool_calls, invalid=True)
if message.additional_kwargs.get("refusal"):
oai_msg["refusal"] = message.additional_kwargs["refusal"]
if isinstance(message, ToolMessage):
@@ -1393,14 +1408,18 @@ def _get_message_openai_role(message: BaseMessage) -> str:
raise ValueError(msg)
def _convert_to_openai_tool_calls(tool_calls: list[ToolCall]) -> list[dict]:
def _convert_to_openai_tool_calls(
tool_calls: Union[list[ToolCall], list[InvalidToolCall]], invalid: bool = False
) -> list[dict]:
return [
{
"type": "function",
"id": tool_call["id"],
"function": {
"name": tool_call["name"],
"arguments": json.dumps(tool_call["args"]),
"arguments": tool_call["args"]
if invalid
else json.dumps(tool_call["args"]),
},
}
for tool_call in tool_calls