fireworks[patch]: read from tool calls attribute (#23820)

This commit is contained in:
ccurme 2024-07-03 11:11:17 -04:00 committed by GitHub
parent e787249af1
commit 54e730f6e4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import json
import logging import logging
import os import os
from operator import itemgetter from operator import itemgetter
@ -46,8 +47,10 @@ from langchain_core.messages import (
FunctionMessageChunk, FunctionMessageChunk,
HumanMessage, HumanMessage,
HumanMessageChunk, HumanMessageChunk,
InvalidToolCall,
SystemMessage, SystemMessage,
SystemMessageChunk, SystemMessageChunk,
ToolCall,
ToolMessage, ToolMessage,
ToolMessageChunk, ToolMessageChunk,
) )
@ -153,11 +156,20 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
# If function call only, content is None not empty string # If function call only, content is None not empty string
if message_dict["content"] == "": if message_dict["content"] == "":
message_dict["content"] = None message_dict["content"] = None
if "tool_calls" in message.additional_kwargs: if message.tool_calls or message.invalid_tool_calls:
message_dict["tool_calls"] = [
_lc_tool_call_to_fireworks_tool_call(tc) for tc in message.tool_calls
] + [
_lc_invalid_tool_call_to_fireworks_tool_call(tc)
for tc in message.invalid_tool_calls
]
elif "tool_calls" in message.additional_kwargs:
message_dict["tool_calls"] = message.additional_kwargs["tool_calls"] message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
# If tool calls only, content is None not empty string # If tool calls only, content is None not empty string
if message_dict["content"] == "": if "tool_calls" in message_dict and message_dict["content"] == "":
message_dict["content"] = None message_dict["content"] = None
else:
pass
elif isinstance(message, SystemMessage): elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content} message_dict = {"role": "system", "content": message.content}
elif isinstance(message, FunctionMessage): elif isinstance(message, FunctionMessage):
@ -922,3 +934,27 @@ class ChatFireworks(BaseChatModel):
def _is_pydantic_class(obj: Any) -> bool: def _is_pydantic_class(obj: Any) -> bool:
return isinstance(obj, type) and issubclass(obj, BaseModel) return isinstance(obj, type) and issubclass(obj, BaseModel)
def _lc_tool_call_to_fireworks_tool_call(tool_call: ToolCall) -> dict:
return {
"type": "function",
"id": tool_call["id"],
"function": {
"name": tool_call["name"],
"arguments": json.dumps(tool_call["args"]),
},
}
def _lc_invalid_tool_call_to_fireworks_tool_call(
invalid_tool_call: InvalidToolCall,
) -> dict:
return {
"type": "function",
"id": invalid_tool_call["id"],
"function": {
"name": invalid_tool_call["name"],
"arguments": invalid_tool_call["args"],
},
}