diff --git a/libs/core/langchain_core/messages/ai.py b/libs/core/langchain_core/messages/ai.py index 9d4a1e01879..b38a7fa1a49 100644 --- a/libs/core/langchain_core/messages/ai.py +++ b/libs/core/langchain_core/messages/ai.py @@ -203,6 +203,13 @@ class AIMessage(BaseMessage): ) -> None: """Specify content as a positional arg or content_blocks for typing support.""" if content_blocks is not None: + # If there are tool calls in content_blocks, but not in tool_calls, add them + content_tool_calls = [ + block for block in content_blocks if block.get("type") == "tool_call" + ] + if content_tool_calls and "tool_calls" not in kwargs: + kwargs["tool_calls"] = content_tool_calls + super().__init__( content=cast("Union[str, list[Union[str, dict]]]", content_blocks), **kwargs, @@ -273,7 +280,9 @@ class AIMessage(BaseMessage): # Ensure "type" is properly set on all tool call-like dicts. if tool_calls := values.get("tool_calls"): values["tool_calls"] = [ - create_tool_call(**{k: v for k, v in tc.items() if k != "type"}) + create_tool_call( + **{k: v for k, v in tc.items() if k not in ("type", "extras")} + ) for tc in tool_calls ] if invalid_tool_calls := values.get("invalid_tool_calls"): diff --git a/libs/core/tests/unit_tests/messages/test_ai.py b/libs/core/tests/unit_tests/messages/test_ai.py index 81981725c50..a7225015c2c 100644 --- a/libs/core/tests/unit_tests/messages/test_ai.py +++ b/libs/core/tests/unit_tests/messages/test_ai.py @@ -253,7 +253,7 @@ def test_content_blocks() -> None: "id": "abc_123", }, ] - missing_tool_call = { + missing_tool_call: types.ToolCall = { "type": "tool_call", "name": "bar", "args": {"c": "d"}, @@ -267,3 +267,20 @@ def test_content_blocks() -> None: ], ) assert message.content_blocks == [*standard_content, missing_tool_call] + + # Check we auto-populate tool_calls + standard_content = [ + {"type": "text", "text": "foo"}, + { + "type": "tool_call", + "name": "foo", + "args": {"a": "b"}, + "id": "abc_123", + }, + missing_tool_call, + ] + message = AIMessage(content_blocks=standard_content) + assert message.tool_calls == [ + {"type": "tool_call", "name": "foo", "args": {"a": "b"}, "id": "abc_123"}, + missing_tool_call, + ]