diff --git a/libs/partners/groq/langchain_groq/chat_models.py b/libs/partners/groq/langchain_groq/chat_models.py index 8820897af1c..69234ce93d6 100644 --- a/libs/partners/groq/langchain_groq/chat_models.py +++ b/libs/partners/groq/langchain_groq/chat_models.py @@ -2,6 +2,7 @@ from __future__ import annotations +import json import os import warnings from operator import itemgetter @@ -44,8 +45,10 @@ from langchain_core.messages import ( FunctionMessageChunk, HumanMessage, HumanMessageChunk, + InvalidToolCall, SystemMessage, SystemMessageChunk, + ToolCall, ToolMessage, ToolMessageChunk, ) @@ -837,7 +840,14 @@ def _convert_message_to_dict(message: BaseMessage) -> dict: # If function call only, content is None not empty string if message_dict["content"] == "": message_dict["content"] = None - if "tool_calls" in message.additional_kwargs: + if message.tool_calls or message.invalid_tool_calls: + message_dict["tool_calls"] = [ + _lc_tool_call_to_groq_tool_call(tc) for tc in message.tool_calls + ] + [ + _lc_invalid_tool_call_to_groq_tool_call(tc) + for tc in message.invalid_tool_calls + ] + elif "tool_calls" in message.additional_kwargs: message_dict["tool_calls"] = message.additional_kwargs["tool_calls"] # If tool calls only, content is None not empty string if message_dict["content"] == "": @@ -944,3 +954,27 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: ) else: return ChatMessage(content=_dict.get("content", ""), role=role) + + +def _lc_tool_call_to_groq_tool_call(tool_call: ToolCall) -> dict: + return { + "type": "function", + "id": tool_call["id"], + "function": { + "name": tool_call["name"], + "arguments": json.dumps(tool_call["args"]), + }, + } + + +def _lc_invalid_tool_call_to_groq_tool_call( + invalid_tool_call: InvalidToolCall, +) -> dict: + return { + "type": "function", + "id": invalid_tool_call["id"], + "function": { + "name": invalid_tool_call["name"], + "arguments": invalid_tool_call["args"], + }, + }