From 54e730f6e4afafb4a667fbc7ff689b6924e727c1 Mon Sep 17 00:00:00 2001 From: ccurme Date: Wed, 3 Jul 2024 11:11:17 -0400 Subject: [PATCH] fireworks[patch]: read from tool calls attribute (#23820) --- .../langchain_fireworks/chat_models.py | 44 +++++++++++++++++-- 1 file changed, 40 insertions(+), 4 deletions(-) diff --git a/libs/partners/fireworks/langchain_fireworks/chat_models.py b/libs/partners/fireworks/langchain_fireworks/chat_models.py index 7eb7ae70f9f..a3194c9a841 100644 --- a/libs/partners/fireworks/langchain_fireworks/chat_models.py +++ b/libs/partners/fireworks/langchain_fireworks/chat_models.py @@ -2,6 +2,7 @@ from __future__ import annotations +import json import logging import os from operator import itemgetter @@ -46,8 +47,10 @@ from langchain_core.messages import ( FunctionMessageChunk, HumanMessage, HumanMessageChunk, + InvalidToolCall, SystemMessage, SystemMessageChunk, + ToolCall, ToolMessage, ToolMessageChunk, ) @@ -153,11 +156,20 @@ def _convert_message_to_dict(message: BaseMessage) -> dict: # If function call only, content is None not empty string if message_dict["content"] == "": message_dict["content"] = None - if "tool_calls" in message.additional_kwargs: + if message.tool_calls or message.invalid_tool_calls: + message_dict["tool_calls"] = [ + _lc_tool_call_to_fireworks_tool_call(tc) for tc in message.tool_calls + ] + [ + _lc_invalid_tool_call_to_fireworks_tool_call(tc) + for tc in message.invalid_tool_calls + ] + elif "tool_calls" in message.additional_kwargs: message_dict["tool_calls"] = message.additional_kwargs["tool_calls"] - # If tool calls only, content is None not empty string - if message_dict["content"] == "": - message_dict["content"] = None + # If tool calls only, content is None not empty string + if "tool_calls" in message_dict and message_dict["content"] == "": + message_dict["content"] = None + else: + pass elif isinstance(message, SystemMessage): message_dict = {"role": "system", "content": message.content} elif isinstance(message, FunctionMessage): @@ -922,3 +934,27 @@ class ChatFireworks(BaseChatModel): def _is_pydantic_class(obj: Any) -> bool: return isinstance(obj, type) and issubclass(obj, BaseModel) + + +def _lc_tool_call_to_fireworks_tool_call(tool_call: ToolCall) -> dict: + return { + "type": "function", + "id": tool_call["id"], + "function": { + "name": tool_call["name"], + "arguments": json.dumps(tool_call["args"]), + }, + } + + +def _lc_invalid_tool_call_to_fireworks_tool_call( + invalid_tool_call: InvalidToolCall, +) -> dict: + return { + "type": "function", + "id": invalid_tool_call["id"], + "function": { + "name": invalid_tool_call["name"], + "arguments": invalid_tool_call["args"], + }, + }