messages v1: some nits, and use create_text_block()

This commit is contained in:
Mason Daugherty 2025-07-31 17:07:42 -04:00
parent 45533fc875
commit 2cb48b685f
No known key found for this signature in database

View File

@ -20,6 +20,7 @@ from langchain_core.messages.ai import (
add_usage, add_usage,
) )
from langchain_core.messages.base import merge_content from langchain_core.messages.base import merge_content
from langchain_core.messages.content_blocks import create_text_block
from langchain_core.messages.tool import ToolCallChunk from langchain_core.messages.tool import ToolCallChunk
from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
from langchain_core.messages.tool import tool_call as create_tool_call from langchain_core.messages.tool import tool_call as create_tool_call
@ -49,13 +50,13 @@ class ResponseMetadata(TypedDict, total=False):
Contains additional information returned by the provider, such as Contains additional information returned by the provider, such as
response headers, service tiers, log probabilities, system fingerprints, etc. response headers, service tiers, log probabilities, system fingerprints, etc.
Extra keys are permitted from what is typed here (via `total=False`), allowing Extra keys are permitted from what is typed here (via ``total=False``), allowing
for provider-specific metadata to be included without breaking the type for provider-specific metadata to be included without breaking the type
definition. definition.
""" """
model_provider: str model_provider: str
"""Name and version of the provider that created the message (e.g., openai).""" """Name and version of the provider that created the message (ex: ``'openai'``)."""
model_name: str model_name: str
"""Name of the model that generated the message.""" """Name of the model that generated the message."""
@ -69,8 +70,8 @@ class AIMessage:
and metadata about the generation process. and metadata about the generation process.
Attributes: Attributes:
type: Message type identifier, always ``'ai'``.
id: Unique identifier for the message. id: Unique identifier for the message.
type: Message type identifier, always "ai".
name: Optional human-readable name for the message. name: Optional human-readable name for the message.
lc_version: Encoding version for the message. lc_version: Encoding version for the message.
content: List of content blocks containing the message data. content: List of content blocks containing the message data.
@ -151,7 +152,7 @@ class AIMessage:
parsed: Optional auto-parsed message contents, if applicable. parsed: Optional auto-parsed message contents, if applicable.
""" """
if isinstance(content, str): if isinstance(content, str):
self.content = [{"type": "text", "text": content}] self.content = [create_text_block(content)]
else: else:
self.content = content self.content = content
@ -170,21 +171,21 @@ class AIMessage:
content_tool_calls = { content_tool_calls = {
block["id"] block["id"]
for block in self.content for block in self.content
if block["type"] == "tool_call" and "id" in block if block.get("type") == "tool_call" and "id" in block
} }
for tool_call in tool_calls: for tool_call in tool_calls:
if "id" in tool_call and tool_call["id"] in content_tool_calls: if "id" in tool_call and tool_call["id"] in content_tool_calls:
continue continue
self.content.append(tool_call) self.content.append(tool_call)
self._tool_calls = [ self._tool_calls = [
block for block in self.content if block["type"] == "tool_call" block for block in self.content if block.get("type") == "tool_call"
] ]
self.invalid_tool_calls = invalid_tool_calls or [] self.invalid_tool_calls = invalid_tool_calls or []
@property @property
def text(self) -> Optional[str]: def text(self) -> Optional[str]:
"""Extract all text content from the AI message as a string.""" """Extract all text content from the AI message as a string."""
text_blocks = [block for block in self.content if block["type"] == "text"] text_blocks = [block for block in self.content if block.get("type") == "text"]
if text_blocks: if text_blocks:
return "".join(block["text"] for block in text_blocks) return "".join(block["text"] for block in text_blocks)
return None return None
@ -213,8 +214,8 @@ class AIMessageChunk(AIMessage):
during streaming generation. Contains partial content and metadata. during streaming generation. Contains partial content and metadata.
Attributes: Attributes:
type: Message type identifier, always ``'ai_chunk'``.
id: Unique identifier for the message chunk. id: Unique identifier for the message chunk.
type: Message type identifier, always "ai_chunk".
name: Optional human-readable name for the message. name: Optional human-readable name for the message.
content: List of content blocks containing partial message data. content: List of content blocks containing partial message data.
tool_call_chunks: Optional list of partial tool call data. tool_call_chunks: Optional list of partial tool call data.
@ -361,7 +362,7 @@ class AIMessageChunk(AIMessage):
text_blocks = [ text_blocks = [
block block
for block in self.content for block in self.content
if block["type"] == "reasoning" and "reasoning" in block if block.get("type") == "reasoning" and "reasoning" in block
] ]
if text_blocks: if text_blocks:
return "".join(block["reasoning"] for block in text_blocks) return "".join(block["reasoning"] for block in text_blocks)
@ -372,7 +373,9 @@ class AIMessageChunk(AIMessage):
"""Get the tool calls made by the AI.""" """Get the tool calls made by the AI."""
if self._tool_calls: if self._tool_calls:
return self._tool_calls return self._tool_calls
tool_calls = [block for block in self.content if block["type"] == "tool_call"] tool_calls = [
block for block in self.content if block.get("type") == "tool_call"
]
if tool_calls: if tool_calls:
self._tool_calls = tool_calls self._tool_calls = tool_calls
return self._tool_calls return self._tool_calls
@ -383,7 +386,7 @@ class AIMessageChunk(AIMessage):
self._tool_calls = value self._tool_calls = value
def __add__(self, other: Any) -> "AIMessageChunk": def __add__(self, other: Any) -> "AIMessageChunk":
"""Add AIMessageChunk to this one.""" """Add ``AIMessageChunk`` to this one."""
if isinstance(other, AIMessageChunk): if isinstance(other, AIMessageChunk):
return add_ai_message_chunks(self, other) return add_ai_message_chunks(self, other)
if isinstance(other, (list, tuple)) and all( if isinstance(other, (list, tuple)) and all(
@ -394,7 +397,7 @@ class AIMessageChunk(AIMessage):
raise NotImplementedError(error_msg) raise NotImplementedError(error_msg)
def to_message(self) -> "AIMessage": def to_message(self) -> "AIMessage":
"""Convert this AIMessageChunk to an AIMessage.""" """Convert this ``AIMessageChunk`` to an AIMessage."""
return AIMessage( return AIMessage(
content=self.content, content=self.content,
id=self.id, id=self.id,
@ -411,7 +414,7 @@ class AIMessageChunk(AIMessage):
def add_ai_message_chunks( def add_ai_message_chunks(
left: AIMessageChunk, *others: AIMessageChunk left: AIMessageChunk, *others: AIMessageChunk
) -> AIMessageChunk: ) -> AIMessageChunk:
"""Add multiple AIMessageChunks together.""" """Add multiple ``AIMessageChunks`` together."""
if not others: if not others:
return left return left
content = merge_content( content = merge_content(
@ -498,10 +501,10 @@ class HumanMessage:
or other content types like images. or other content types like images.
Attributes: Attributes:
type: Message type identifier, always ``'human'``.
id: Unique identifier for the message. id: Unique identifier for the message.
content: List of content blocks containing the user's input. content: List of content blocks containing the user's input.
name: Optional human-readable name for the message. name: Optional human-readable name for the message.
type: Message type identifier, always "human".
""" """
id: str id: str
@ -558,7 +561,7 @@ class HumanMessage:
Concatenated string of all text blocks in the message. Concatenated string of all text blocks in the message.
""" """
return "".join( return "".join(
block["text"] for block in self.content if block["type"] == "text" block["text"] for block in self.content if block.get("type") == "text"
) )
@ -570,9 +573,9 @@ class SystemMessage:
behavior and understanding of the conversation. behavior and understanding of the conversation.
Attributes: Attributes:
type: Message type identifier, always ``'system'``.
id: Unique identifier for the message. id: Unique identifier for the message.
content: List of content blocks containing system instructions. content: List of content blocks containing system instructions.
type: Message type identifier, always "system".
""" """
id: str id: str
@ -604,7 +607,7 @@ class SystemMessage:
custom_role: Optional[str] = None custom_role: Optional[str] = None
"""If provided, a custom role for the system message. """If provided, a custom role for the system message.
Example: ``"developer"``. Example: ``'developer'``.
Integration packages may use this field to assign the system message role if it Integration packages may use this field to assign the system message role if it
contains a recognized value. contains a recognized value.
@ -637,7 +640,7 @@ class SystemMessage:
def text(self) -> str: def text(self) -> str:
"""Extract all text content from the system message.""" """Extract all text content from the system message."""
return "".join( return "".join(
block["text"] for block in self.content if block["type"] == "text" block["text"] for block in self.content if block.get("type") == "text"
) )
@ -649,12 +652,12 @@ class ToolMessage:
including the result data and execution status. including the result data and execution status.
Attributes: Attributes:
type: Message type identifier, always ``'tool'``.
id: Unique identifier for the message. id: Unique identifier for the message.
tool_call_id: ID of the tool call this message responds to. tool_call_id: ID of the tool call this message responds to.
content: The result content from tool execution. content: The result content from tool execution.
artifact: Optional app-side payload not intended for the model. artifact: Optional app-side payload not intended for the model.
status: Execution status ("success" or "error"). status: Execution status ("success" or "error").
type: Message type identifier, always "tool".
""" """
id: str id: str
@ -713,7 +716,7 @@ class ToolMessage:
id: Optional unique identifier for the message. id: Optional unique identifier for the message.
name: Optional human-readable name for the message. name: Optional human-readable name for the message.
artifact: Optional app-side payload not intended for the model. artifact: Optional app-side payload not intended for the model.
status: Execution status ("success" or "error"). status: Execution status (``'success'`` or ``'error'``).
""" """
self.id = _ensure_id(id) self.id = _ensure_id(id)
self.tool_call_id = tool_call_id self.tool_call_id = tool_call_id
@ -729,7 +732,7 @@ class ToolMessage:
def text(self) -> str: def text(self) -> str:
"""Extract all text content from the tool message.""" """Extract all text content from the tool message."""
return "".join( return "".join(
block["text"] for block in self.content if block["type"] == "text" block["text"] for block in self.content if block.get("type") == "text"
) )
def __post_init__(self) -> None: def __post_init__(self) -> None: