mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-04 04:07:54 +00:00
fireworks[patch]: read from tool calls attribute (#23820)
This commit is contained in:
parent
e787249af1
commit
54e730f6e4
@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from operator import itemgetter
|
||||
@ -46,8 +47,10 @@ from langchain_core.messages import (
|
||||
FunctionMessageChunk,
|
||||
HumanMessage,
|
||||
HumanMessageChunk,
|
||||
InvalidToolCall,
|
||||
SystemMessage,
|
||||
SystemMessageChunk,
|
||||
ToolCall,
|
||||
ToolMessage,
|
||||
ToolMessageChunk,
|
||||
)
|
||||
@ -153,11 +156,20 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
# If function call only, content is None not empty string
|
||||
if message_dict["content"] == "":
|
||||
message_dict["content"] = None
|
||||
if "tool_calls" in message.additional_kwargs:
|
||||
if message.tool_calls or message.invalid_tool_calls:
|
||||
message_dict["tool_calls"] = [
|
||||
_lc_tool_call_to_fireworks_tool_call(tc) for tc in message.tool_calls
|
||||
] + [
|
||||
_lc_invalid_tool_call_to_fireworks_tool_call(tc)
|
||||
for tc in message.invalid_tool_calls
|
||||
]
|
||||
elif "tool_calls" in message.additional_kwargs:
|
||||
message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
|
||||
# If tool calls only, content is None not empty string
|
||||
if message_dict["content"] == "":
|
||||
if "tool_calls" in message_dict and message_dict["content"] == "":
|
||||
message_dict["content"] = None
|
||||
else:
|
||||
pass
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
elif isinstance(message, FunctionMessage):
|
||||
@ -922,3 +934,27 @@ class ChatFireworks(BaseChatModel):
|
||||
|
||||
def _is_pydantic_class(obj: Any) -> bool:
|
||||
return isinstance(obj, type) and issubclass(obj, BaseModel)
|
||||
|
||||
|
||||
def _lc_tool_call_to_fireworks_tool_call(tool_call: ToolCall) -> dict:
|
||||
return {
|
||||
"type": "function",
|
||||
"id": tool_call["id"],
|
||||
"function": {
|
||||
"name": tool_call["name"],
|
||||
"arguments": json.dumps(tool_call["args"]),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _lc_invalid_tool_call_to_fireworks_tool_call(
|
||||
invalid_tool_call: InvalidToolCall,
|
||||
) -> dict:
|
||||
return {
|
||||
"type": "function",
|
||||
"id": invalid_tool_call["id"],
|
||||
"function": {
|
||||
"name": invalid_tool_call["name"],
|
||||
"arguments": invalid_tool_call["args"],
|
||||
},
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user