mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-07 12:06:43 +00:00
implement beta_content
This commit is contained in:
parent
26038608a4
commit
679a9e7c8f
@ -8,6 +8,7 @@ from typing import Any, Literal, Optional, Union, cast
|
||||
from pydantic import model_validator
|
||||
from typing_extensions import NotRequired, Self, TypedDict, override
|
||||
|
||||
from langchain_core.messages import content_blocks as types
|
||||
from langchain_core.messages.base import (
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
@ -196,6 +197,79 @@ class AIMessage(BaseMessage):
|
||||
"invalid_tool_calls": self.invalid_tool_calls,
|
||||
}
|
||||
|
||||
@property
|
||||
def beta_content(self) -> list[types.ContentBlock]:
|
||||
"""Return the content as a list of standard ContentBlocks.
|
||||
|
||||
To use this property, the corresponding chat model must support
|
||||
``output_version="v1"`` or higher:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.chat_models import init_chat_model
|
||||
|
||||
llm = init_chat_model("...", output_version="v1")
|
||||
|
||||
otherwise, does best-effort parsing to standard types.
|
||||
"""
|
||||
blocks: list[types.ContentBlock] = []
|
||||
if isinstance(self.content, str):
|
||||
if self.content:
|
||||
blocks.append({"type": "text", "text": self.content})
|
||||
else:
|
||||
pass
|
||||
|
||||
elif isinstance(self.content, list):
|
||||
for item in self.content:
|
||||
if isinstance(item, str):
|
||||
blocks.append({"type": "text", "text": item})
|
||||
|
||||
elif isinstance(item, dict):
|
||||
item_type = item.get("type")
|
||||
if item_type == "text":
|
||||
blocks.append(cast("types.TextContentBlock", item))
|
||||
elif item_type == "tool_call":
|
||||
blocks.append(cast("types.ToolCallContentBlock", item))
|
||||
elif item_type == "reasoning":
|
||||
blocks.append(cast("types.ReasoningContentBlock", item))
|
||||
elif item_type == "non_standard":
|
||||
blocks.append(cast("types.NonStandardContentBlock", item))
|
||||
elif source_type := item.get("source_type"):
|
||||
if source_type == "url":
|
||||
blocks.append(cast("types.URLContentBlock", item))
|
||||
elif source_type == "base64":
|
||||
blocks.append(cast("types.Base64ContentBlock", item))
|
||||
elif source_type == "text":
|
||||
blocks.append(cast("types.PlainTextContentBlock", item))
|
||||
elif source_type == "id":
|
||||
blocks.append(cast("types.IDContentBlock", item))
|
||||
else:
|
||||
msg = f"Unknown source_type {source_type} in content block."
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
msg = f"Unknown content block type {item_type}."
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
|
||||
# Add from tool_calls if missing from content
|
||||
content_tool_call_ids = {
|
||||
block.get("id")
|
||||
for block in self.content
|
||||
if isinstance(block, dict) and block.get("type") == "tool_call"
|
||||
}
|
||||
for tool_call in self.tool_calls:
|
||||
if (id_ := tool_call.get("id")) and id_ not in content_tool_call_ids:
|
||||
tool_call_block: types.ToolCallContentBlock = {
|
||||
"type": "tool_call",
|
||||
"id": id_,
|
||||
}
|
||||
blocks.append(tool_call_block)
|
||||
|
||||
return blocks
|
||||
|
||||
# TODO: remove this logic if possible, reducing breaking nature of changes
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
|
@ -7,7 +7,6 @@ from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||
from pydantic import ConfigDict, Field
|
||||
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.messages import ContentBlock
|
||||
from langchain_core.utils import get_bolded_text
|
||||
from langchain_core.utils._merge import merge_dicts, merge_lists
|
||||
from langchain_core.utils.interactive_env import is_interactive_env
|
||||
@ -24,7 +23,7 @@ class BaseMessage(Serializable):
|
||||
Messages are the inputs and outputs of ChatModels.
|
||||
"""
|
||||
|
||||
content: Union[str, list[Union[str, ContentBlock, dict]]]
|
||||
content: Union[str, list[Union[str, dict]]]
|
||||
"""The string contents of the message."""
|
||||
|
||||
additional_kwargs: dict = Field(default_factory=dict)
|
||||
|
@ -6,21 +6,30 @@ EXPECTED_ALL = [
|
||||
"AIMessage",
|
||||
"AIMessageChunk",
|
||||
"AnyMessage",
|
||||
"Base64ContentBlock",
|
||||
"BaseMessage",
|
||||
"BaseMessageChunk",
|
||||
"ContentBlock",
|
||||
"ChatMessage",
|
||||
"ChatMessageChunk",
|
||||
"DocumentCitation",
|
||||
"FunctionMessage",
|
||||
"FunctionMessageChunk",
|
||||
"HumanMessage",
|
||||
"HumanMessageChunk",
|
||||
"InvalidToolCall",
|
||||
"NonStandardAnnotation",
|
||||
"NonStandardContentBlock",
|
||||
"SystemMessage",
|
||||
"SystemMessageChunk",
|
||||
"TextContentBlock",
|
||||
"ToolCall",
|
||||
"ToolCallChunk",
|
||||
"ToolCallContentBlock",
|
||||
"ToolMessage",
|
||||
"ToolMessageChunk",
|
||||
"UrlCitation",
|
||||
"ReasoningContentBlock",
|
||||
"RemoveMessage",
|
||||
"convert_to_messages",
|
||||
"get_buffer_string",
|
||||
|
Loading…
Reference in New Issue
Block a user