mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-12 09:56:27 +00:00
groq: read tool calls from .tool_calls attribute (#22096)
This commit is contained in:
parent
96c21dfe56
commit
0ea1e89b2c
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
@ -44,8 +45,10 @@ from langchain_core.messages import (
|
|||||||
FunctionMessageChunk,
|
FunctionMessageChunk,
|
||||||
HumanMessage,
|
HumanMessage,
|
||||||
HumanMessageChunk,
|
HumanMessageChunk,
|
||||||
|
InvalidToolCall,
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
SystemMessageChunk,
|
SystemMessageChunk,
|
||||||
|
ToolCall,
|
||||||
ToolMessage,
|
ToolMessage,
|
||||||
ToolMessageChunk,
|
ToolMessageChunk,
|
||||||
)
|
)
|
||||||
@ -837,7 +840,14 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
|||||||
# If function call only, content is None not empty string
|
# If function call only, content is None not empty string
|
||||||
if message_dict["content"] == "":
|
if message_dict["content"] == "":
|
||||||
message_dict["content"] = None
|
message_dict["content"] = None
|
||||||
if "tool_calls" in message.additional_kwargs:
|
if message.tool_calls or message.invalid_tool_calls:
|
||||||
|
message_dict["tool_calls"] = [
|
||||||
|
_lc_tool_call_to_groq_tool_call(tc) for tc in message.tool_calls
|
||||||
|
] + [
|
||||||
|
_lc_invalid_tool_call_to_groq_tool_call(tc)
|
||||||
|
for tc in message.invalid_tool_calls
|
||||||
|
]
|
||||||
|
elif "tool_calls" in message.additional_kwargs:
|
||||||
message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
|
message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
|
||||||
# If tool calls only, content is None not empty string
|
# If tool calls only, content is None not empty string
|
||||||
if message_dict["content"] == "":
|
if message_dict["content"] == "":
|
||||||
@ -944,3 +954,27 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return ChatMessage(content=_dict.get("content", ""), role=role)
|
return ChatMessage(content=_dict.get("content", ""), role=role)
|
||||||
|
|
||||||
|
|
||||||
|
def _lc_tool_call_to_groq_tool_call(tool_call: ToolCall) -> dict:
|
||||||
|
return {
|
||||||
|
"type": "function",
|
||||||
|
"id": tool_call["id"],
|
||||||
|
"function": {
|
||||||
|
"name": tool_call["name"],
|
||||||
|
"arguments": json.dumps(tool_call["args"]),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _lc_invalid_tool_call_to_groq_tool_call(
|
||||||
|
invalid_tool_call: InvalidToolCall,
|
||||||
|
) -> dict:
|
||||||
|
return {
|
||||||
|
"type": "function",
|
||||||
|
"id": invalid_tool_call["id"],
|
||||||
|
"function": {
|
||||||
|
"name": invalid_tool_call["name"],
|
||||||
|
"arguments": invalid_tool_call["args"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user