This commit is contained in:
Chester Curme 2025-07-11 15:16:37 -04:00
parent 7ab615409c
commit 7e740e5e1f
3 changed files with 37 additions and 23 deletions

View File

@ -226,28 +226,12 @@ class AIMessage(BaseMessage):
elif isinstance(item, dict):
item_type = item.get("type")
if item_type == "text":
blocks.append(cast("types.TextContentBlock", item))
elif item_type == "tool_call":
blocks.append(cast("types.ToolCallContentBlock", item))
elif item_type == "reasoning":
blocks.append(cast("types.ReasoningContentBlock", item))
elif item_type == "non_standard":
blocks.append(cast("types.NonStandardContentBlock", item))
elif source_type := item.get("source_type"):
if source_type == "url":
blocks.append(cast("types.URLContentBlock", item))
elif source_type == "base64":
blocks.append(cast("types.Base64ContentBlock", item))
elif source_type == "text":
blocks.append(cast("types.PlainTextContentBlock", item))
elif source_type == "id":
blocks.append(cast("types.IDContentBlock", item))
else:
msg = f"Unknown source_type {source_type} in content block."
raise ValueError(msg)
else:
msg = f"Unknown content block type {item_type}."
if item_type not in types.KNOWN_BLOCK_TYPES:
msg = (
f"Non-standard content block type '{item_type}'. Ensure "
"the model supports `output_version='v1'` or higher and "
"that this attribute is set on initialization."
)
raise ValueError(msg)
else:
pass

View File

@ -4,7 +4,7 @@ import warnings
from typing import Any, Literal, Union
from pydantic import TypeAdapter, ValidationError
from typing_extensions import NotRequired, TypedDict
from typing_extensions import NotRequired, TypedDict, get_args, get_origin
# Text and annotations
@ -177,6 +177,23 @@ ContentBlock = Union[
]
def _extract_typedict_type_values(union_type: Any) -> set[str]:
"""Extract the values of the 'type' field from a TypedDict union type."""
result: set[str] = set()
for value in get_args(union_type):
annotation = value.__annotations__["type"]
if get_origin(annotation) is Literal:
result.update(get_args(annotation))
else:
msg = f"{value} 'type' is not a Literal"
raise ValueError(msg)
return result
# {"text", "tool_call", "reasoning", "non_standard", "image", "audio", "file"}
KNOWN_BLOCK_TYPES = _extract_typedict_type_values(ContentBlock)
def is_data_content_block(
content_block: dict,
) -> bool:

View File

@ -30,6 +30,7 @@ from langchain_core.messages import (
messages_from_dict,
messages_to_dict,
)
from langchain_core.messages.content_blocks import KNOWN_BLOCK_TYPES
from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
from langchain_core.messages.tool import tool_call as create_tool_call
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
@ -1197,3 +1198,15 @@ def test_convert_to_openai_image_block() -> None:
}
result = convert_to_openai_image_block(input_block)
assert result == expected
def test_known_block_types() -> None:
assert {
"text",
"tool_call",
"reasoning",
"non_standard",
"image",
"audio",
"file",
} == KNOWN_BLOCK_TYPES