diff --git a/libs/core/langchain_core/messages/ai.py b/libs/core/langchain_core/messages/ai.py index a829641cce8..2f800319335 100644 --- a/libs/core/langchain_core/messages/ai.py +++ b/libs/core/langchain_core/messages/ai.py @@ -213,7 +213,11 @@ class AIMessage(BaseMessage): otherwise, does best-effort parsing to standard types. """ blocks: list[types.ContentBlock] = [] - content = [self.content] if isinstance(self.content, str) else self.content + content = ( + [self.content] + if isinstance(self.content, str) and self.content + else self.content + ) for item in content: if isinstance(item, str): blocks.append({"type": "text", "text": item}) @@ -227,6 +231,7 @@ class AIMessage(BaseMessage): "that this attribute is set on initialization." ) raise ValueError(msg) + blocks.append(cast("types.ContentBlock", item)) else: pass diff --git a/libs/core/tests/unit_tests/messages/test_ai.py b/libs/core/tests/unit_tests/messages/test_ai.py index d36d0347128..ef6a38750b8 100644 --- a/libs/core/tests/unit_tests/messages/test_ai.py +++ b/libs/core/tests/unit_tests/messages/test_ai.py @@ -1,5 +1,8 @@ +from typing import Union, cast + from langchain_core.load import dumpd, load from langchain_core.messages import AIMessage, AIMessageChunk +from langchain_core.messages import content_blocks as types from langchain_core.messages.ai import ( InputTokenDetails, OutputTokenDetails, @@ -196,3 +199,71 @@ def test_add_ai_message_chunks_usage() -> None: output_token_details=OutputTokenDetails(audio=1, reasoning=2), ), ) + + +class ReasoningContentBlockWithID(types.ReasoningContentBlock): + id: str + + +def test_beta_content() -> None: + # Simple case + message = AIMessage("Hello") + assert len(message.beta_content) == 1 + assert message.beta_content[0]["type"] == "text" + for block in message.beta_content: + if block["type"] == "text": + text_block: types.TextContentBlock = block + assert text_block == {"type": "text", "text": "Hello"} + + # With tool calls + message = AIMessage( + "", + tool_calls=[ + {"type": "tool_call", "name": "foo", "args": {"a": "b"}, "id": "abc_123"} + ], + ) + assert len(message.beta_content) == 1 + assert message.beta_content[0]["type"] == "tool_call" + for block in message.beta_content: + if block["type"] == "tool_call": + tool_call_block: types.ToolCallContentBlock = block + assert tool_call_block == {"type": "tool_call", "id": "abc_123"} + + # With standard blocks + reasoning_with_id: ReasoningContentBlockWithID = { + "type": "reasoning", + "reasoning": "foo", + "id": "rs_abc123", + } + standard_content: list[types.ContentBlock] = [ + {"type": "reasoning", "reasoning": "foo"}, + reasoning_with_id, + {"type": "text", "text": "bar"}, + { + "type": "text", + "text": "baz", + "annotations": [{"type": "url_citation", "url": "http://example.com"}], + }, + { + "type": "image", + "source_type": "url", + "url": "http://example.com/image.png", + }, + { + "type": "non_standard", + "value": {"custom_key": "custom_value", "another_key": 123}, + }, + { + "type": "tool_call", + "id": "abc_123", + }, + ] + message = AIMessage( + cast("list[Union[str, dict]]", standard_content), + tool_calls=[ + {"type": "tool_call", "name": "foo", "args": {"a": "b"}, "id": "abc_123"}, + {"type": "tool_call", "name": "bar", "args": {"c": "d"}, "id": "abc_234"}, + ], + ) + missing_tool_call = {"type": "tool_call", "id": "abc_234"} + assert message.beta_content == [*standard_content, missing_tool_call]