mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 14:43:07 +00:00
mistral[patch]: add IDs to tool calls (#20299)
Mistral gives us one ID per response, no individual IDs for tool calls.
```python
from langchain.agents import AgentExecutor, create_tool_calling_agent, tool
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_mistralai import ChatMistralAI
prompt = ChatPromptTemplate.from_messages(
[
("system", "You are a helpful assistant"),
("human", "{input}"),
MessagesPlaceholder("agent_scratchpad"),
]
)
model = ChatMistralAI(model="mistral-large-latest", temperature=0)
@tool
def magic_function(input: int) -> int:
"""Applies a magic function to an input."""
return input + 2
tools = [magic_function]
agent = create_tool_calling_agent(model, tools, prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
agent_executor.invoke({"input": "what is the value of magic_function(3)?"})
```
---------
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
committed by
William Fu-Hinthorn
parent
269be353d6
commit
e6ea42d7f0
@@ -67,7 +67,7 @@ class ToolMessageChunk(ToolMessage, BaseMessageChunk):
|
||||
|
||||
|
||||
class ToolCall(TypedDict):
|
||||
"""A call to a tool.
|
||||
"""Represents a request to call a tool.
|
||||
|
||||
Attributes:
|
||||
name: (str) the name of the tool to be called
|
||||
|
||||
@@ -44,7 +44,7 @@ def parse_tool_call(
|
||||
"args": function_args or {},
|
||||
}
|
||||
if return_id:
|
||||
parsed["id"] = raw_tool_call["id"]
|
||||
parsed["id"] = raw_tool_call.get("id")
|
||||
return parsed
|
||||
|
||||
|
||||
@@ -67,9 +67,9 @@ def parse_tool_calls(
|
||||
partial: bool = False,
|
||||
strict: bool = False,
|
||||
return_id: bool = True,
|
||||
) -> List[dict]:
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Parse a list of tool calls."""
|
||||
final_tools = []
|
||||
final_tools: List[Dict[str, Any]] = []
|
||||
exceptions = []
|
||||
for tool_call in raw_tool_calls:
|
||||
try:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from operator import itemgetter
|
||||
from typing import (
|
||||
Any,
|
||||
@@ -91,14 +92,18 @@ def _convert_mistral_chat_message_to_message(
|
||||
for raw_tool_call in raw_tool_calls:
|
||||
try:
|
||||
parsed: dict = cast(
|
||||
dict, parse_tool_call(raw_tool_call, return_id=False)
|
||||
)
|
||||
tool_calls.append(
|
||||
{
|
||||
**parsed,
|
||||
**{"id": None},
|
||||
},
|
||||
dict, parse_tool_call(raw_tool_call, return_id=True)
|
||||
)
|
||||
if not parsed["id"]:
|
||||
tool_call_id = uuid.uuid4().hex[:]
|
||||
tool_calls.append(
|
||||
{
|
||||
**parsed,
|
||||
**{"id": tool_call_id},
|
||||
},
|
||||
)
|
||||
else:
|
||||
tool_calls.append(parsed)
|
||||
except Exception as e:
|
||||
invalid_tool_calls.append(
|
||||
dict(make_invalid_tool_call(raw_tool_call, str(e)))
|
||||
@@ -160,15 +165,20 @@ def _convert_delta_to_message_chunk(
|
||||
if raw_tool_calls := _delta.get("tool_calls"):
|
||||
additional_kwargs["tool_calls"] = raw_tool_calls
|
||||
try:
|
||||
tool_call_chunks = [
|
||||
{
|
||||
"name": rtc["function"].get("name"),
|
||||
"args": rtc["function"].get("arguments"),
|
||||
"id": rtc.get("id"),
|
||||
"index": rtc.get("index"),
|
||||
}
|
||||
for rtc in raw_tool_calls
|
||||
]
|
||||
tool_call_chunks = []
|
||||
for raw_tool_call in raw_tool_calls:
|
||||
if not raw_tool_call.get("index") and not raw_tool_call.get("id"):
|
||||
tool_call_id = uuid.uuid4().hex[:]
|
||||
else:
|
||||
tool_call_id = raw_tool_call.get("id")
|
||||
tool_call_chunks.append(
|
||||
{
|
||||
"name": raw_tool_call["function"].get("name"),
|
||||
"args": raw_tool_call["function"].get("arguments"),
|
||||
"id": tool_call_id,
|
||||
"index": raw_tool_call.get("index"),
|
||||
}
|
||||
)
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
@@ -195,15 +205,17 @@ def _convert_message_to_mistral_chat_message(
|
||||
return dict(role="user", content=message.content)
|
||||
elif isinstance(message, AIMessage):
|
||||
if "tool_calls" in message.additional_kwargs:
|
||||
tool_calls = [
|
||||
{
|
||||
tool_calls = []
|
||||
for tc in message.additional_kwargs["tool_calls"]:
|
||||
chunk = {
|
||||
"function": {
|
||||
"name": tc["function"]["name"],
|
||||
"arguments": tc["function"]["arguments"],
|
||||
}
|
||||
}
|
||||
for tc in message.additional_kwargs["tool_calls"]
|
||||
]
|
||||
if _id := tc.get("id"):
|
||||
chunk["id"] = _id
|
||||
tool_calls.append(chunk)
|
||||
else:
|
||||
tool_calls = []
|
||||
return {
|
||||
|
||||
@@ -7,8 +7,6 @@ from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
HumanMessage,
|
||||
ToolCall,
|
||||
ToolCallChunk,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
@@ -168,9 +166,10 @@ def test_tool_call() -> None:
|
||||
|
||||
result = tool_llm.invoke("Erick, 27 years old")
|
||||
assert isinstance(result, AIMessage)
|
||||
assert result.tool_calls == [
|
||||
ToolCall(name="Person", args={"name": "Erick", "age": 27}, id=None)
|
||||
]
|
||||
assert len(result.tool_calls) == 1
|
||||
tool_call = result.tool_calls[0]
|
||||
assert tool_call["name"] == "Person"
|
||||
assert tool_call["args"] == {"name": "Erick", "age": 27}
|
||||
|
||||
|
||||
def test_streaming_tool_call() -> None:
|
||||
@@ -201,11 +200,10 @@ def test_streaming_tool_call() -> None:
|
||||
}
|
||||
|
||||
assert isinstance(chunk, AIMessageChunk)
|
||||
assert chunk.tool_call_chunks == [
|
||||
ToolCallChunk(
|
||||
name="Person", args='{"name": "Erick", "age": 27}', id=None, index=None
|
||||
)
|
||||
]
|
||||
assert len(chunk.tool_call_chunks) == 1
|
||||
tool_call_chunk = chunk.tool_call_chunks[0]
|
||||
assert tool_call_chunk["name"] == "Person"
|
||||
assert tool_call_chunk["args"] == '{"name": "Erick", "age": 27}'
|
||||
|
||||
# where it doesn't call the tool
|
||||
strm = tool_llm.stream("What is 2+2?")
|
||||
|
||||
@@ -128,6 +128,7 @@ async def test_astream_with_callback() -> None:
|
||||
|
||||
def test__convert_dict_to_message_tool_call() -> None:
|
||||
raw_tool_call = {
|
||||
"id": "abc123",
|
||||
"function": {
|
||||
"arguments": '{"name":"Sally","hair_color":"green"}',
|
||||
"name": "GenerateUsername",
|
||||
@@ -142,7 +143,7 @@ def test__convert_dict_to_message_tool_call() -> None:
|
||||
ToolCall(
|
||||
name="GenerateUsername",
|
||||
args={"name": "Sally", "hair_color": "green"},
|
||||
id=None,
|
||||
id="abc123",
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -152,12 +153,14 @@ def test__convert_dict_to_message_tool_call() -> None:
|
||||
# Test malformed tool call
|
||||
raw_tool_calls = [
|
||||
{
|
||||
"id": "abc123",
|
||||
"function": {
|
||||
"arguments": "oops",
|
||||
"name": "GenerateUsername",
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "def456",
|
||||
"function": {
|
||||
"arguments": '{"name":"Sally","hair_color":"green"}',
|
||||
"name": "GenerateUsername",
|
||||
@@ -174,14 +177,14 @@ def test__convert_dict_to_message_tool_call() -> None:
|
||||
name="GenerateUsername",
|
||||
args="oops",
|
||||
error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)", # noqa: E501
|
||||
id=None,
|
||||
id="abc123",
|
||||
),
|
||||
],
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
name="GenerateUsername",
|
||||
args={"name": "Sally", "hair_color": "green"},
|
||||
id=None,
|
||||
id="def456",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user