mistral[patch]: add IDs to tool calls (#20299)

Mistral returns a single ID per response rather than an ID for each individual tool call. This patch falls back to a locally generated UUID for each parsed tool call whenever the API omits one, so downstream consumers such as tool-calling agents always see a populated `id`. The agent example below exercises the change end to end:

```python
from langchain.agents import AgentExecutor, create_tool_calling_agent, tool
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_mistralai import ChatMistralAI


prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a helpful assistant"),
        ("human", "{input}"),
        MessagesPlaceholder("agent_scratchpad"),
    ]
)
model = ChatMistralAI(model="mistral-large-latest", temperature=0)

@tool
def magic_function(input: int) -> int:
    """Applies a magic function to an input."""
    return input + 2

tools = [magic_function]

agent = create_tool_calling_agent(model, tools, prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

agent_executor.invoke({"input": "what is the value of magic_function(3)?"})
```
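
With this change, each parsed tool call also carries an `id` when the model is called directly. A minimal sketch, assuming a configured `MISTRAL_API_KEY` and reusing `model` and `magic_function` from the example above:

```python
llm_with_tools = model.bind_tools([magic_function])
ai_msg = llm_with_tools.invoke("what is the value of magic_function(3)?")

# Each entry in ai_msg.tool_calls is a dict with "name", "args", and "id";
# the "id" is now populated (a locally generated UUID when Mistral omits it).
for tool_call in ai_msg.tool_calls:
    print(tool_call["name"], tool_call["args"], tool_call["id"])
```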

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Author: ccurme
Date: 2024-04-11 11:09:30 -04:00
Committed by: William Fu-Hinthorn
Parent: 269be353d6
Commit: e6ea42d7f0
5 changed files with 50 additions and 37 deletions

View File

```diff
@@ -67,7 +67,7 @@ class ToolMessageChunk(ToolMessage, BaseMessageChunk):
 class ToolCall(TypedDict):
-    """A call to a tool.
+    """Represents a request to call a tool.
 
     Attributes:
         name: (str) the name of the tool to be called
```
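
`ToolCall` itself is just a `TypedDict`, so a populated instance is a plain dict. A small illustrative sketch (the values are made up):

```python
from langchain_core.messages import ToolCall

# Illustrative values only; "id" is whatever the provider (or a local fallback) supplies.
call = ToolCall(name="magic_function", args={"input": 3}, id="abc123")
assert call["name"] == "magic_function"
```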

View File

```diff
@@ -44,7 +44,7 @@ def parse_tool_call(
         "args": function_args or {},
     }
     if return_id:
-        parsed["id"] = raw_tool_call["id"]
+        parsed["id"] = raw_tool_call.get("id")
     return parsed
@@ -67,9 +67,9 @@ def parse_tool_calls(
     partial: bool = False,
     strict: bool = False,
     return_id: bool = True,
-) -> List[dict]:
+) -> List[Dict[str, Any]]:
     """Parse a list of tool calls."""
-    final_tools = []
+    final_tools: List[Dict[str, Any]] = []
     exceptions = []
     for tool_call in raw_tool_calls:
         try:
```
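
The switch to `.get("id")` means a raw tool call with no `"id"` key now parses cleanly instead of raising `KeyError`, and the missing ID comes back as `None` for the caller to fill in. A rough sketch of that behavior:

```python
from langchain_core.output_parsers.openai_tools import parse_tool_call

# A Mistral-style raw tool call with no per-call "id" field.
raw_tool_call = {"function": {"name": "magic_function", "arguments": '{"input": 3}'}}

parsed = parse_tool_call(raw_tool_call, return_id=True)
assert parsed["name"] == "magic_function"
assert parsed["args"] == {"input": 3}
assert parsed["id"] is None  # previously a KeyError; the caller decides how to fill this in
```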

View File

```diff
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import logging
+import uuid
 from operator import itemgetter
 from typing import (
     Any,
@@ -91,14 +92,18 @@ def _convert_mistral_chat_message_to_message(
             for raw_tool_call in raw_tool_calls:
                 try:
                     parsed: dict = cast(
-                        dict, parse_tool_call(raw_tool_call, return_id=False)
-                    )
-                    tool_calls.append(
-                        {
-                            **parsed,
-                            **{"id": None},
-                        },
-                    )
+                        dict, parse_tool_call(raw_tool_call, return_id=True)
+                    )
+                    if not parsed["id"]:
+                        tool_call_id = uuid.uuid4().hex[:]
+                        tool_calls.append(
+                            {
+                                **parsed,
+                                **{"id": tool_call_id},
+                            },
+                        )
+                    else:
+                        tool_calls.append(parsed)
                 except Exception as e:
                     invalid_tool_calls.append(
                         dict(make_invalid_tool_call(raw_tool_call, str(e)))
@@ -160,15 +165,20 @@ def _convert_delta_to_message_chunk(
         if raw_tool_calls := _delta.get("tool_calls"):
             additional_kwargs["tool_calls"] = raw_tool_calls
             try:
-                tool_call_chunks = [
-                    {
-                        "name": rtc["function"].get("name"),
-                        "args": rtc["function"].get("arguments"),
-                        "id": rtc.get("id"),
-                        "index": rtc.get("index"),
-                    }
-                    for rtc in raw_tool_calls
-                ]
+                tool_call_chunks = []
+                for raw_tool_call in raw_tool_calls:
+                    if not raw_tool_call.get("index") and not raw_tool_call.get("id"):
+                        tool_call_id = uuid.uuid4().hex[:]
+                    else:
+                        tool_call_id = raw_tool_call.get("id")
+                    tool_call_chunks.append(
+                        {
+                            "name": raw_tool_call["function"].get("name"),
+                            "args": raw_tool_call["function"].get("arguments"),
+                            "id": tool_call_id,
+                            "index": raw_tool_call.get("index"),
+                        }
+                    )
             except KeyError:
                 pass
         else:
@@ -195,15 +205,17 @@ def _convert_message_to_mistral_chat_message(
         return dict(role="user", content=message.content)
     elif isinstance(message, AIMessage):
         if "tool_calls" in message.additional_kwargs:
-            tool_calls = [
-                {
+            tool_calls = []
+            for tc in message.additional_kwargs["tool_calls"]:
+                chunk = {
                     "function": {
                         "name": tc["function"]["name"],
                         "arguments": tc["function"]["arguments"],
                     }
                 }
-                for tc in message.additional_kwargs["tool_calls"]
-            ]
+                if _id := tc.get("id"):
+                    chunk["id"] = _id
+                tool_calls.append(chunk)
         else:
             tool_calls = []
         return {
```
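
The fallback itself is small enough to show in isolation; this is a sketch of the same logic rather than a public API of the integration:

```python
import uuid

from langchain_core.output_parsers.openai_tools import parse_tool_call

raw_tool_call = {"function": {"name": "magic_function", "arguments": '{"input": 3}'}}

parsed = parse_tool_call(raw_tool_call, return_id=True)
if not parsed["id"]:
    # Mirrors the converters above: synthesize a local ID when Mistral does not return one.
    parsed["id"] = uuid.uuid4().hex
print(parsed)
```

On the way back to the Mistral API, an `id` field is only attached when the stored raw tool call actually carries one (see the last hunk above).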

View File

```diff
@@ -7,8 +7,6 @@ from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
     HumanMessage,
-    ToolCall,
-    ToolCallChunk,
 )
 from langchain_core.pydantic_v1 import BaseModel
@@ -168,9 +166,10 @@ def test_tool_call() -> None:
     result = tool_llm.invoke("Erick, 27 years old")
     assert isinstance(result, AIMessage)
-    assert result.tool_calls == [
-        ToolCall(name="Person", args={"name": "Erick", "age": 27}, id=None)
-    ]
+    assert len(result.tool_calls) == 1
+    tool_call = result.tool_calls[0]
+    assert tool_call["name"] == "Person"
+    assert tool_call["args"] == {"name": "Erick", "age": 27}
 
 
 def test_streaming_tool_call() -> None:
@@ -201,11 +200,10 @@ def test_streaming_tool_call() -> None:
     }
     assert isinstance(chunk, AIMessageChunk)
-    assert chunk.tool_call_chunks == [
-        ToolCallChunk(
-            name="Person", args='{"name": "Erick", "age": 27}', id=None, index=None
-        )
-    ]
+    assert len(chunk.tool_call_chunks) == 1
+    tool_call_chunk = chunk.tool_call_chunks[0]
+    assert tool_call_chunk["name"] == "Person"
+    assert tool_call_chunk["args"] == '{"name": "Erick", "age": 27}'
 
     # where it doesn't call the tool
     strm = tool_llm.stream("What is 2+2?")
```
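
For streaming, the same guarantee now applies to `tool_call_chunks`. A hedged sketch, reusing `model` and `magic_function` from the example in the PR description:

```python
llm_with_tools = model.bind_tools([magic_function])

for chunk in llm_with_tools.stream("what is the value of magic_function(3)?"):
    if chunk.tool_call_chunks:
        # Each chunk dict carries "name", "args", "id", and "index";
        # "id" is no longer left as None when Mistral omits it.
        print(chunk.tool_call_chunks)
```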

View File

```diff
@@ -128,6 +128,7 @@ async def test_astream_with_callback() -> None:
 def test__convert_dict_to_message_tool_call() -> None:
     raw_tool_call = {
+        "id": "abc123",
         "function": {
             "arguments": '{"name":"Sally","hair_color":"green"}',
             "name": "GenerateUsername",
@@ -142,7 +143,7 @@ def test__convert_dict_to_message_tool_call() -> None:
             ToolCall(
                 name="GenerateUsername",
                 args={"name": "Sally", "hair_color": "green"},
-                id=None,
+                id="abc123",
             )
         ],
     )
@@ -152,12 +153,14 @@ def test__convert_dict_to_message_tool_call() -> None:
     # Test malformed tool call
     raw_tool_calls = [
         {
+            "id": "abc123",
             "function": {
                 "arguments": "oops",
                 "name": "GenerateUsername",
             },
         },
         {
+            "id": "def456",
             "function": {
                 "arguments": '{"name":"Sally","hair_color":"green"}',
                 "name": "GenerateUsername",
@@ -174,14 +177,14 @@ def test__convert_dict_to_message_tool_call() -> None:
                 name="GenerateUsername",
                 args="oops",
                 error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)",  # noqa: E501
-                id=None,
+                id="abc123",
             ),
         ],
         tool_calls=[
             ToolCall(
                 name="GenerateUsername",
                 args={"name": "Sally", "hair_color": "green"},
-                id=None,
+                id="def456",
             ),
         ],
     )
```