diff --git a/libs/core/langchain_core/messages/tool.py b/libs/core/langchain_core/messages/tool.py index 169e6856ae9..c2d06f47fab 100644 --- a/libs/core/langchain_core/messages/tool.py +++ b/libs/core/langchain_core/messages/tool.py @@ -67,7 +67,7 @@ class ToolMessageChunk(ToolMessage, BaseMessageChunk): class ToolCall(TypedDict): - """A call to a tool. + """Represents a request to call a tool. Attributes: name: (str) the name of the tool to be called diff --git a/libs/core/langchain_core/output_parsers/openai_tools.py b/libs/core/langchain_core/output_parsers/openai_tools.py index da1f638588d..f79bac5e283 100644 --- a/libs/core/langchain_core/output_parsers/openai_tools.py +++ b/libs/core/langchain_core/output_parsers/openai_tools.py @@ -44,7 +44,7 @@ def parse_tool_call( "args": function_args or {}, } if return_id: - parsed["id"] = raw_tool_call["id"] + parsed["id"] = raw_tool_call.get("id") return parsed @@ -67,9 +67,9 @@ def parse_tool_calls( partial: bool = False, strict: bool = False, return_id: bool = True, -) -> List[dict]: +) -> List[Dict[str, Any]]: """Parse a list of tool calls.""" - final_tools = [] + final_tools: List[Dict[str, Any]] = [] exceptions = [] for tool_call in raw_tool_calls: try: diff --git a/libs/partners/mistralai/langchain_mistralai/chat_models.py b/libs/partners/mistralai/langchain_mistralai/chat_models.py index c1d4c642e24..8b248ebb98e 100644 --- a/libs/partners/mistralai/langchain_mistralai/chat_models.py +++ b/libs/partners/mistralai/langchain_mistralai/chat_models.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import uuid from operator import itemgetter from typing import ( Any, @@ -91,14 +92,18 @@ def _convert_mistral_chat_message_to_message( for raw_tool_call in raw_tool_calls: try: parsed: dict = cast( - dict, parse_tool_call(raw_tool_call, return_id=False) - ) - tool_calls.append( - { - **parsed, - **{"id": None}, - }, + dict, parse_tool_call(raw_tool_call, return_id=True) ) + if not parsed["id"]: + tool_call_id = uuid.uuid4().hex[:] + tool_calls.append( + { + **parsed, + **{"id": tool_call_id}, + }, + ) + else: + tool_calls.append(parsed) except Exception as e: invalid_tool_calls.append( dict(make_invalid_tool_call(raw_tool_call, str(e))) @@ -160,15 +165,20 @@ def _convert_delta_to_message_chunk( if raw_tool_calls := _delta.get("tool_calls"): additional_kwargs["tool_calls"] = raw_tool_calls try: - tool_call_chunks = [ - { - "name": rtc["function"].get("name"), - "args": rtc["function"].get("arguments"), - "id": rtc.get("id"), - "index": rtc.get("index"), - } - for rtc in raw_tool_calls - ] + tool_call_chunks = [] + for raw_tool_call in raw_tool_calls: + if not raw_tool_call.get("index") and not raw_tool_call.get("id"): + tool_call_id = uuid.uuid4().hex[:] + else: + tool_call_id = raw_tool_call.get("id") + tool_call_chunks.append( + { + "name": raw_tool_call["function"].get("name"), + "args": raw_tool_call["function"].get("arguments"), + "id": tool_call_id, + "index": raw_tool_call.get("index"), + } + ) except KeyError: pass else: @@ -195,15 +205,17 @@ def _convert_message_to_mistral_chat_message( return dict(role="user", content=message.content) elif isinstance(message, AIMessage): if "tool_calls" in message.additional_kwargs: - tool_calls = [ - { + tool_calls = [] + for tc in message.additional_kwargs["tool_calls"]: + chunk = { "function": { "name": tc["function"]["name"], "arguments": tc["function"]["arguments"], } } - for tc in message.additional_kwargs["tool_calls"] - ] + if _id := tc.get("id"): + chunk["id"] = _id + tool_calls.append(chunk) else: tool_calls = [] return { diff --git a/libs/partners/mistralai/tests/integration_tests/test_chat_models.py b/libs/partners/mistralai/tests/integration_tests/test_chat_models.py index e5e78c91086..7dd19a4b9ce 100644 --- a/libs/partners/mistralai/tests/integration_tests/test_chat_models.py +++ b/libs/partners/mistralai/tests/integration_tests/test_chat_models.py @@ -7,8 +7,6 @@ from langchain_core.messages import ( AIMessage, AIMessageChunk, HumanMessage, - ToolCall, - ToolCallChunk, ) from langchain_core.pydantic_v1 import BaseModel @@ -168,9 +166,10 @@ def test_tool_call() -> None: result = tool_llm.invoke("Erick, 27 years old") assert isinstance(result, AIMessage) - assert result.tool_calls == [ - ToolCall(name="Person", args={"name": "Erick", "age": 27}, id=None) - ] + assert len(result.tool_calls) == 1 + tool_call = result.tool_calls[0] + assert tool_call["name"] == "Person" + assert tool_call["args"] == {"name": "Erick", "age": 27} def test_streaming_tool_call() -> None: @@ -201,11 +200,10 @@ def test_streaming_tool_call() -> None: } assert isinstance(chunk, AIMessageChunk) - assert chunk.tool_call_chunks == [ - ToolCallChunk( - name="Person", args='{"name": "Erick", "age": 27}', id=None, index=None - ) - ] + assert len(chunk.tool_call_chunks) == 1 + tool_call_chunk = chunk.tool_call_chunks[0] + assert tool_call_chunk["name"] == "Person" + assert tool_call_chunk["args"] == '{"name": "Erick", "age": 27}' # where it doesn't call the tool strm = tool_llm.stream("What is 2+2?") diff --git a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py index 18fca396bb7..96c637b5a2f 100644 --- a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py +++ b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py @@ -128,6 +128,7 @@ async def test_astream_with_callback() -> None: def test__convert_dict_to_message_tool_call() -> None: raw_tool_call = { + "id": "abc123", "function": { "arguments": '{"name":"Sally","hair_color":"green"}', "name": "GenerateUsername", @@ -142,7 +143,7 @@ def test__convert_dict_to_message_tool_call() -> None: ToolCall( name="GenerateUsername", args={"name": "Sally", "hair_color": "green"}, - id=None, + id="abc123", ) ], ) @@ -152,12 +153,14 @@ def test__convert_dict_to_message_tool_call() -> None: # Test malformed tool call raw_tool_calls = [ { + "id": "abc123", "function": { "arguments": "oops", "name": "GenerateUsername", }, }, { + "id": "def456", "function": { "arguments": '{"name":"Sally","hair_color":"green"}', "name": "GenerateUsername", @@ -174,14 +177,14 @@ def test__convert_dict_to_message_tool_call() -> None: name="GenerateUsername", args="oops", error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)", # noqa: E501 - id=None, + id="abc123", ), ], tool_calls=[ ToolCall( name="GenerateUsername", args={"name": "Sally", "hair_color": "green"}, - id=None, + id="def456", ), ], )