mistral: read tool calls from AIMessage (#20554)

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
ccurme
2024-04-17 13:38:24 -04:00
committed by GitHub
parent f257909699
commit 4a17951900
4 changed files with 56 additions and 11 deletions

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import json
import logging
import uuid
from operator import itemgetter
@@ -42,8 +43,10 @@ from langchain_core.messages import (
ChatMessageChunk,
HumanMessage,
HumanMessageChunk,
InvalidToolCall,
SystemMessage,
SystemMessageChunk,
ToolCall,
ToolMessage,
)
from langchain_core.output_parsers.base import OutputParserLike
@@ -223,6 +226,34 @@ def _convert_delta_to_message_chunk(
return default_class(content=content)
def _format_tool_call_for_mistral(tool_call: ToolCall) -> dict:
"""Format Langchain ToolCall to dict expected by Mistral."""
result: Dict[str, Any] = {
"function": {
"name": tool_call["name"],
"arguments": json.dumps(tool_call["args"]),
}
}
if _id := tool_call.get("id"):
result["id"] = _id
return result
def _format_invalid_tool_call_for_mistral(invalid_tool_call: InvalidToolCall) -> dict:
"""Format Langchain InvalidToolCall to dict expected by Mistral."""
result: Dict[str, Any] = {
"function": {
"name": invalid_tool_call["name"],
"arguments": invalid_tool_call["args"],
}
}
if _id := invalid_tool_call.get("id"):
result["id"] = _id
return result
def _convert_message_to_mistral_chat_message(
message: BaseMessage,
) -> Dict:
@@ -231,8 +262,15 @@ def _convert_message_to_mistral_chat_message(
elif isinstance(message, HumanMessage):
return dict(role="user", content=message.content)
elif isinstance(message, AIMessage):
if "tool_calls" in message.additional_kwargs:
tool_calls = []
tool_calls = []
if message.tool_calls or message.invalid_tool_calls:
for tool_call in message.tool_calls:
tool_calls.append(_format_tool_call_for_mistral(tool_call))
for invalid_tool_call in message.invalid_tool_calls:
tool_calls.append(
_format_invalid_tool_call_for_mistral(invalid_tool_call)
)
elif "tool_calls" in message.additional_kwargs:
for tc in message.additional_kwargs["tool_calls"]:
chunk = {
"function": {
@@ -244,7 +282,7 @@ def _convert_message_to_mistral_chat_message(
chunk["id"] = _id
tool_calls.append(chunk)
else:
tool_calls = []
pass
return {
"role": "assistant",
"content": message.content,