mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
mistral: read tool calls from AIMessage (#20554)
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from operator import itemgetter
|
||||
@@ -42,8 +43,10 @@ from langchain_core.messages import (
|
||||
ChatMessageChunk,
|
||||
HumanMessage,
|
||||
HumanMessageChunk,
|
||||
InvalidToolCall,
|
||||
SystemMessage,
|
||||
SystemMessageChunk,
|
||||
ToolCall,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.output_parsers.base import OutputParserLike
|
||||
@@ -223,6 +226,34 @@ def _convert_delta_to_message_chunk(
|
||||
return default_class(content=content)
|
||||
|
||||
|
||||
def _format_tool_call_for_mistral(tool_call: ToolCall) -> dict:
|
||||
"""Format Langchain ToolCall to dict expected by Mistral."""
|
||||
result: Dict[str, Any] = {
|
||||
"function": {
|
||||
"name": tool_call["name"],
|
||||
"arguments": json.dumps(tool_call["args"]),
|
||||
}
|
||||
}
|
||||
if _id := tool_call.get("id"):
|
||||
result["id"] = _id
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _format_invalid_tool_call_for_mistral(invalid_tool_call: InvalidToolCall) -> dict:
|
||||
"""Format Langchain InvalidToolCall to dict expected by Mistral."""
|
||||
result: Dict[str, Any] = {
|
||||
"function": {
|
||||
"name": invalid_tool_call["name"],
|
||||
"arguments": invalid_tool_call["args"],
|
||||
}
|
||||
}
|
||||
if _id := invalid_tool_call.get("id"):
|
||||
result["id"] = _id
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _convert_message_to_mistral_chat_message(
|
||||
message: BaseMessage,
|
||||
) -> Dict:
|
||||
@@ -231,8 +262,15 @@ def _convert_message_to_mistral_chat_message(
|
||||
elif isinstance(message, HumanMessage):
|
||||
return dict(role="user", content=message.content)
|
||||
elif isinstance(message, AIMessage):
|
||||
if "tool_calls" in message.additional_kwargs:
|
||||
tool_calls = []
|
||||
tool_calls = []
|
||||
if message.tool_calls or message.invalid_tool_calls:
|
||||
for tool_call in message.tool_calls:
|
||||
tool_calls.append(_format_tool_call_for_mistral(tool_call))
|
||||
for invalid_tool_call in message.invalid_tool_calls:
|
||||
tool_calls.append(
|
||||
_format_invalid_tool_call_for_mistral(invalid_tool_call)
|
||||
)
|
||||
elif "tool_calls" in message.additional_kwargs:
|
||||
for tc in message.additional_kwargs["tool_calls"]:
|
||||
chunk = {
|
||||
"function": {
|
||||
@@ -244,7 +282,7 @@ def _convert_message_to_mistral_chat_message(
|
||||
chunk["id"] = _id
|
||||
tool_calls.append(chunk)
|
||||
else:
|
||||
tool_calls = []
|
||||
pass
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": message.content,
|
||||
|
||||
Reference in New Issue
Block a user