mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-07 03:56:39 +00:00
cr
This commit is contained in:
parent
7ab615409c
commit
7e740e5e1f
@ -226,28 +226,12 @@ class AIMessage(BaseMessage):
|
||||
|
||||
elif isinstance(item, dict):
|
||||
item_type = item.get("type")
|
||||
if item_type == "text":
|
||||
blocks.append(cast("types.TextContentBlock", item))
|
||||
elif item_type == "tool_call":
|
||||
blocks.append(cast("types.ToolCallContentBlock", item))
|
||||
elif item_type == "reasoning":
|
||||
blocks.append(cast("types.ReasoningContentBlock", item))
|
||||
elif item_type == "non_standard":
|
||||
blocks.append(cast("types.NonStandardContentBlock", item))
|
||||
elif source_type := item.get("source_type"):
|
||||
if source_type == "url":
|
||||
blocks.append(cast("types.URLContentBlock", item))
|
||||
elif source_type == "base64":
|
||||
blocks.append(cast("types.Base64ContentBlock", item))
|
||||
elif source_type == "text":
|
||||
blocks.append(cast("types.PlainTextContentBlock", item))
|
||||
elif source_type == "id":
|
||||
blocks.append(cast("types.IDContentBlock", item))
|
||||
else:
|
||||
msg = f"Unknown source_type {source_type} in content block."
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
msg = f"Unknown content block type {item_type}."
|
||||
if item_type not in types.KNOWN_BLOCK_TYPES:
|
||||
msg = (
|
||||
f"Non-standard content block type '{item_type}'. Ensure "
|
||||
"the model supports `output_version='v1'` or higher and "
|
||||
"that this attribute is set on initialization."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
pass
|
||||
|
@ -4,7 +4,7 @@ import warnings
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
from typing_extensions import NotRequired, TypedDict, get_args, get_origin
|
||||
|
||||
|
||||
# Text and annotations
|
||||
@ -177,6 +177,23 @@ ContentBlock = Union[
|
||||
]
|
||||
|
||||
|
||||
def _extract_typedict_type_values(union_type: Any) -> set[str]:
|
||||
"""Extract the values of the 'type' field from a TypedDict union type."""
|
||||
result: set[str] = set()
|
||||
for value in get_args(union_type):
|
||||
annotation = value.__annotations__["type"]
|
||||
if get_origin(annotation) is Literal:
|
||||
result.update(get_args(annotation))
|
||||
else:
|
||||
msg = f"{value} 'type' is not a Literal"
|
||||
raise ValueError(msg)
|
||||
return result
|
||||
|
||||
|
||||
# {"text", "tool_call", "reasoning", "non_standard", "image", "audio", "file"}
|
||||
KNOWN_BLOCK_TYPES = _extract_typedict_type_values(ContentBlock)
|
||||
|
||||
|
||||
def is_data_content_block(
|
||||
content_block: dict,
|
||||
) -> bool:
|
||||
|
@ -30,6 +30,7 @@ from langchain_core.messages import (
|
||||
messages_from_dict,
|
||||
messages_to_dict,
|
||||
)
|
||||
from langchain_core.messages.content_blocks import KNOWN_BLOCK_TYPES
|
||||
from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
|
||||
from langchain_core.messages.tool import tool_call as create_tool_call
|
||||
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
|
||||
@ -1197,3 +1198,15 @@ def test_convert_to_openai_image_block() -> None:
|
||||
}
|
||||
result = convert_to_openai_image_block(input_block)
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_known_block_types() -> None:
|
||||
assert {
|
||||
"text",
|
||||
"tool_call",
|
||||
"reasoning",
|
||||
"non_standard",
|
||||
"image",
|
||||
"audio",
|
||||
"file",
|
||||
} == KNOWN_BLOCK_TYPES
|
||||
|
Loading…
Reference in New Issue
Block a user