From 7e740e5e1f2a3de4a6fb4012875b5f36835ff73b Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Fri, 11 Jul 2025 15:16:37 -0400 Subject: [PATCH] cr --- libs/core/langchain_core/messages/ai.py | 28 ++++--------------- .../langchain_core/messages/content_blocks.py | 19 ++++++++++++- libs/core/tests/unit_tests/test_messages.py | 13 +++++++++ 3 files changed, 37 insertions(+), 23 deletions(-) diff --git a/libs/core/langchain_core/messages/ai.py b/libs/core/langchain_core/messages/ai.py index cdd8611fbe2..d386f8f7235 100644 --- a/libs/core/langchain_core/messages/ai.py +++ b/libs/core/langchain_core/messages/ai.py @@ -226,28 +226,12 @@ class AIMessage(BaseMessage): elif isinstance(item, dict): item_type = item.get("type") - if item_type == "text": - blocks.append(cast("types.TextContentBlock", item)) - elif item_type == "tool_call": - blocks.append(cast("types.ToolCallContentBlock", item)) - elif item_type == "reasoning": - blocks.append(cast("types.ReasoningContentBlock", item)) - elif item_type == "non_standard": - blocks.append(cast("types.NonStandardContentBlock", item)) - elif source_type := item.get("source_type"): - if source_type == "url": - blocks.append(cast("types.URLContentBlock", item)) - elif source_type == "base64": - blocks.append(cast("types.Base64ContentBlock", item)) - elif source_type == "text": - blocks.append(cast("types.PlainTextContentBlock", item)) - elif source_type == "id": - blocks.append(cast("types.IDContentBlock", item)) - else: - msg = f"Unknown source_type {source_type} in content block." - raise ValueError(msg) - else: - msg = f"Unknown content block type {item_type}." + if item_type not in types.KNOWN_BLOCK_TYPES: + msg = ( + f"Non-standard content block type '{item_type}'. Ensure " + "the model supports `output_version='v1'` or higher and " + "that this attribute is set on initialization." + ) raise ValueError(msg) else: pass diff --git a/libs/core/langchain_core/messages/content_blocks.py b/libs/core/langchain_core/messages/content_blocks.py index 8da0a797226..c6145e92dcb 100644 --- a/libs/core/langchain_core/messages/content_blocks.py +++ b/libs/core/langchain_core/messages/content_blocks.py @@ -4,7 +4,7 @@ import warnings from typing import Any, Literal, Union from pydantic import TypeAdapter, ValidationError -from typing_extensions import NotRequired, TypedDict +from typing_extensions import NotRequired, TypedDict, get_args, get_origin # Text and annotations @@ -177,6 +177,23 @@ ContentBlock = Union[ ] +def _extract_typedict_type_values(union_type: Any) -> set[str]: + """Extract the values of the 'type' field from a TypedDict union type.""" + result: set[str] = set() + for value in get_args(union_type): + annotation = value.__annotations__["type"] + if get_origin(annotation) is Literal: + result.update(get_args(annotation)) + else: + msg = f"{value} 'type' is not a Literal" + raise ValueError(msg) + return result + + +# {"text", "tool_call", "reasoning", "non_standard", "image", "audio", "file"} +KNOWN_BLOCK_TYPES = _extract_typedict_type_values(ContentBlock) + + def is_data_content_block( content_block: dict, ) -> bool: diff --git a/libs/core/tests/unit_tests/test_messages.py b/libs/core/tests/unit_tests/test_messages.py index e6cc725cfca..dd04e506cd6 100644 --- a/libs/core/tests/unit_tests/test_messages.py +++ b/libs/core/tests/unit_tests/test_messages.py @@ -30,6 +30,7 @@ from langchain_core.messages import ( messages_from_dict, messages_to_dict, ) +from langchain_core.messages.content_blocks import KNOWN_BLOCK_TYPES from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call from langchain_core.messages.tool import tool_call as create_tool_call from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk @@ -1197,3 +1198,15 @@ def test_convert_to_openai_image_block() -> None: } result = convert_to_openai_image_block(input_block) assert result == expected + + +def test_known_block_types() -> None: + assert { + "text", + "tool_call", + "reasoning", + "non_standard", + "image", + "audio", + "file", + } == KNOWN_BLOCK_TYPES