mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-16 23:13:31 +00:00
mistral: read tool calls from AIMessage (#20554)
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
@@ -42,8 +43,10 @@ from langchain_core.messages import (
|
|||||||
ChatMessageChunk,
|
ChatMessageChunk,
|
||||||
HumanMessage,
|
HumanMessage,
|
||||||
HumanMessageChunk,
|
HumanMessageChunk,
|
||||||
|
InvalidToolCall,
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
SystemMessageChunk,
|
SystemMessageChunk,
|
||||||
|
ToolCall,
|
||||||
ToolMessage,
|
ToolMessage,
|
||||||
)
|
)
|
||||||
from langchain_core.output_parsers.base import OutputParserLike
|
from langchain_core.output_parsers.base import OutputParserLike
|
||||||
@@ -223,6 +226,34 @@ def _convert_delta_to_message_chunk(
|
|||||||
return default_class(content=content)
|
return default_class(content=content)
|
||||||
|
|
||||||
|
|
||||||
|
def _format_tool_call_for_mistral(tool_call: ToolCall) -> dict:
|
||||||
|
"""Format Langchain ToolCall to dict expected by Mistral."""
|
||||||
|
result: Dict[str, Any] = {
|
||||||
|
"function": {
|
||||||
|
"name": tool_call["name"],
|
||||||
|
"arguments": json.dumps(tool_call["args"]),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _id := tool_call.get("id"):
|
||||||
|
result["id"] = _id
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _format_invalid_tool_call_for_mistral(invalid_tool_call: InvalidToolCall) -> dict:
|
||||||
|
"""Format Langchain InvalidToolCall to dict expected by Mistral."""
|
||||||
|
result: Dict[str, Any] = {
|
||||||
|
"function": {
|
||||||
|
"name": invalid_tool_call["name"],
|
||||||
|
"arguments": invalid_tool_call["args"],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _id := invalid_tool_call.get("id"):
|
||||||
|
result["id"] = _id
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def _convert_message_to_mistral_chat_message(
|
def _convert_message_to_mistral_chat_message(
|
||||||
message: BaseMessage,
|
message: BaseMessage,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
@@ -231,8 +262,15 @@ def _convert_message_to_mistral_chat_message(
|
|||||||
elif isinstance(message, HumanMessage):
|
elif isinstance(message, HumanMessage):
|
||||||
return dict(role="user", content=message.content)
|
return dict(role="user", content=message.content)
|
||||||
elif isinstance(message, AIMessage):
|
elif isinstance(message, AIMessage):
|
||||||
if "tool_calls" in message.additional_kwargs:
|
tool_calls = []
|
||||||
tool_calls = []
|
if message.tool_calls or message.invalid_tool_calls:
|
||||||
|
for tool_call in message.tool_calls:
|
||||||
|
tool_calls.append(_format_tool_call_for_mistral(tool_call))
|
||||||
|
for invalid_tool_call in message.invalid_tool_calls:
|
||||||
|
tool_calls.append(
|
||||||
|
_format_invalid_tool_call_for_mistral(invalid_tool_call)
|
||||||
|
)
|
||||||
|
elif "tool_calls" in message.additional_kwargs:
|
||||||
for tc in message.additional_kwargs["tool_calls"]:
|
for tc in message.additional_kwargs["tool_calls"]:
|
||||||
chunk = {
|
chunk = {
|
||||||
"function": {
|
"function": {
|
||||||
@@ -244,7 +282,7 @@ def _convert_message_to_mistral_chat_message(
|
|||||||
chunk["id"] = _id
|
chunk["id"] = _id
|
||||||
tool_calls.append(chunk)
|
tool_calls.append(chunk)
|
||||||
else:
|
else:
|
||||||
tool_calls = []
|
pass
|
||||||
return {
|
return {
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": message.content,
|
"content": message.content,
|
||||||
|
@@ -138,7 +138,7 @@ def test_structured_output() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_streaming_structured_output() -> None:
|
def test_streaming_structured_output() -> None:
|
||||||
llm = ChatMistralAI(model="mistral-large", temperature=0)
|
llm = ChatMistralAI(model="mistral-large-latest", temperature=0)
|
||||||
|
|
||||||
class Person(BaseModel):
|
class Person(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
@@ -156,7 +156,7 @@ def test_streaming_structured_output() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_tool_call() -> None:
|
def test_tool_call() -> None:
|
||||||
llm = ChatMistralAI(model="mistral-large", temperature=0)
|
llm = ChatMistralAI(model="mistral-large-latest", temperature=0)
|
||||||
|
|
||||||
class Person(BaseModel):
|
class Person(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
@@ -173,7 +173,7 @@ def test_tool_call() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_streaming_tool_call() -> None:
|
def test_streaming_tool_call() -> None:
|
||||||
llm = ChatMistralAI(model="mistral-large", temperature=0)
|
llm = ChatMistralAI(model="mistral-large-latest", temperature=0)
|
||||||
|
|
||||||
class Person(BaseModel):
|
class Person(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
|
@@ -13,3 +13,10 @@ class TestMistralStandard(ChatModelIntegrationTests):
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||||
return ChatMistralAI
|
return ChatMistralAI
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def chat_model_params(self) -> dict:
|
||||||
|
return {
|
||||||
|
"model": "mistral-large-latest",
|
||||||
|
"temperature": 0,
|
||||||
|
}
|
||||||
|
@@ -130,7 +130,7 @@ def test__convert_dict_to_message_tool_call() -> None:
|
|||||||
raw_tool_call = {
|
raw_tool_call = {
|
||||||
"id": "abc123",
|
"id": "abc123",
|
||||||
"function": {
|
"function": {
|
||||||
"arguments": '{"name":"Sally","hair_color":"green"}',
|
"arguments": '{"name": "Sally", "hair_color": "green"}',
|
||||||
"name": "GenerateUsername",
|
"name": "GenerateUsername",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -153,16 +153,16 @@ def test__convert_dict_to_message_tool_call() -> None:
|
|||||||
# Test malformed tool call
|
# Test malformed tool call
|
||||||
raw_tool_calls = [
|
raw_tool_calls = [
|
||||||
{
|
{
|
||||||
"id": "abc123",
|
"id": "def456",
|
||||||
"function": {
|
"function": {
|
||||||
"arguments": "oops",
|
"arguments": '{"name": "Sally", "hair_color": "green"}',
|
||||||
"name": "GenerateUsername",
|
"name": "GenerateUsername",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": "def456",
|
"id": "abc123",
|
||||||
"function": {
|
"function": {
|
||||||
"arguments": '{"name":"Sally","hair_color":"green"}',
|
"arguments": "oops",
|
||||||
"name": "GenerateUsername",
|
"name": "GenerateUsername",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
Reference in New Issue
Block a user