start on openai

This commit is contained in:
Chester Curme
2025-07-24 17:12:22 -04:00
parent 041b196145
commit 4899857042
7 changed files with 482 additions and 566 deletions

View File

@@ -307,7 +307,7 @@ def _convert_from_v1_message(message: MessageV1) -> BaseMessage:
id=message.id,
name=message.name,
tool_calls=message.tool_calls,
response_metadata=cast(dict, message.response_metadata),
response_metadata=cast("dict", message.response_metadata),
)
if isinstance(message, AIMessageChunkV1):
return AIMessageChunk(
@@ -315,7 +315,7 @@ def _convert_from_v1_message(message: MessageV1) -> BaseMessage:
id=message.id,
name=message.name,
tool_call_chunks=message.tool_call_chunks,
response_metadata=cast(dict, message.response_metadata),
response_metadata=cast("dict", message.response_metadata),
)
if isinstance(message, HumanMessageV1):
return HumanMessage(

View File

@@ -5,6 +5,8 @@ import uuid
from dataclasses import dataclass, field
from typing import Any, Literal, Optional, TypedDict, Union, cast, get_args
from pydantic import BaseModel
import langchain_core.messages.content_blocks as types
from langchain_core.messages.ai import _LC_ID_PREFIX, UsageMetadata, add_usage
from langchain_core.messages.base import merge_content
@@ -32,20 +34,20 @@ def _ensure_id(id_val: Optional[str]) -> str:
return id_val or str(uuid.uuid4())
class Provider(TypedDict):
"""Information about the provider that generated the message.
class ResponseMetadata(TypedDict, total=False):
"""Metadata about the response from the AI provider.
Contains metadata about the AI provider and model used to generate content.
Contains additional information returned by the provider, such as
response headers, service tiers, log probabilities, system fingerprints, etc.
Attributes:
name: Name and version of the provider that created the content block.
model_name: Name of the model that generated the content block.
Extra keys are permitted from what is typed here.
"""
name: str
"""Name and version of the provider that created the content block."""
model_provider: str
"""Name and version of the provider that created the message (e.g., openai)."""
model_name: str
"""Name of the model that generated the content block."""
"""Name of the model that generated the message."""
@dataclass
@@ -91,21 +93,29 @@ class AIMessage:
usage_metadata: Optional[UsageMetadata] = None
"""If provided, usage metadata for a message, such as token counts."""
response_metadata: dict = field(default_factory=dict)
response_metadata: ResponseMetadata = field(
default_factory=lambda: ResponseMetadata()
)
"""Metadata about the response.
This field should include non-standard data returned by the provider, such as
response headers, service tiers, or log probabilities.
"""
parsed: Optional[Union[dict[str, Any], BaseModel]] = None
"""Auto-parsed message contents, if applicable."""
def __init__(
self,
content: Union[str, list[types.ContentBlock]],
id: Optional[str] = None,
name: Optional[str] = None,
lc_version: str = "v1",
response_metadata: Optional[dict] = None,
response_metadata: Optional[ResponseMetadata] = None,
usage_metadata: Optional[UsageMetadata] = None,
tool_calls: Optional[list[types.ToolCall]] = None,
invalid_tool_calls: Optional[list[types.InvalidToolCall]] = None,
parsed: Optional[Union[dict[str, Any], BaseModel]] = None,
):
"""Initialize an AI message.
@@ -116,6 +126,11 @@ class AIMessage:
lc_version: Encoding version for the message.
response_metadata: Optional metadata about the response.
usage_metadata: Optional metadata about token usage.
tool_calls: Optional list of tool calls made by the AI. Tool calls should
generally be included in message content. If passed on init, they will
be added to the content list.
invalid_tool_calls: Optional list of tool calls that failed validation.
parsed: Optional auto-parsed message contents, if applicable.
"""
if isinstance(content, str):
self.content = [{"type": "text", "text": content}]
@@ -126,13 +141,27 @@ class AIMessage:
self.name = name
self.lc_version = lc_version
self.usage_metadata = usage_metadata
self.parsed = parsed
if response_metadata is None:
self.response_metadata = {}
else:
self.response_metadata = response_metadata
self._tool_calls: list[types.ToolCall] = []
self._invalid_tool_calls: list[types.InvalidToolCall] = []
# Add tool calls to content if provided on init
if tool_calls:
content_tool_calls = {
block["id"]
for block in self.content
if block["type"] == "tool_call" and "id" in block
}
for tool_call in tool_calls:
if "id" in tool_call and tool_call["id"] in content_tool_calls:
continue
self.content.append(tool_call)
self._tool_calls = [
block for block in self.content if block["type"] == "tool_call"
]
self.invalid_tool_calls = invalid_tool_calls or []
@property
def text(self) -> Optional[str]:
@@ -150,7 +179,7 @@ class AIMessage:
tool_calls = [block for block in self.content if block["type"] == "tool_call"]
if tool_calls:
self._tool_calls = tool_calls
return self._tool_calls
return [block for block in self.content if block["type"] == "tool_call"]
@tool_calls.setter
def tool_calls(self, value: list[types.ToolCall]) -> None:
@@ -202,13 +231,16 @@ class AIMessageChunk:
These data represent incremental usage statistics, as opposed to a running total.
"""
response_metadata: dict = field(init=False)
response_metadata: ResponseMetadata = field(init=False)
"""Metadata about the response chunk.
This field should include non-standard data returned by the provider, such as
response headers, service tiers, or log probabilities.
"""
parsed: Optional[Union[dict[str, Any], BaseModel]] = None
"""Auto-parsed message contents, if applicable."""
tool_call_chunks: list[types.ToolCallChunk] = field(init=False)
def __init__(
@@ -217,9 +249,10 @@ class AIMessageChunk:
id: Optional[str] = None,
name: Optional[str] = None,
lc_version: str = "v1",
response_metadata: Optional[dict] = None,
response_metadata: Optional[ResponseMetadata] = None,
usage_metadata: Optional[UsageMetadata] = None,
tool_call_chunks: Optional[list[types.ToolCallChunk]] = None,
parsed: Optional[Union[dict[str, Any], BaseModel]] = None,
):
"""Initialize an AI message.
@@ -231,6 +264,7 @@ class AIMessageChunk:
response_metadata: Optional metadata about the response.
usage_metadata: Optional metadata about token usage.
tool_call_chunks: Optional list of partial tool call data.
parsed: Optional auto-parsed message contents, if applicable.
"""
if isinstance(content, str):
self.content = [{"type": "text", "text": content, "index": 0}]
@@ -241,6 +275,7 @@ class AIMessageChunk:
self.name = name
self.lc_version = lc_version
self.usage_metadata = usage_metadata
self.parsed = parsed
if response_metadata is None:
self.response_metadata = {}
else:
@@ -251,7 +286,7 @@ class AIMessageChunk:
self.tool_call_chunks = tool_call_chunks
self._tool_calls: list[types.ToolCall] = []
self._invalid_tool_calls: list[types.InvalidToolCall] = []
self.invalid_tool_calls: list[types.InvalidToolCall] = []
self._init_tool_calls()
def _init_tool_calls(self) -> None:
@@ -264,7 +299,7 @@ class AIMessageChunk:
ValueError: If the tool call chunks are malformed.
"""
self._tool_calls = []
self._invalid_tool_calls = []
self.invalid_tool_calls = []
if not self.tool_call_chunks:
if self._tool_calls:
self.tool_call_chunks = [
@@ -276,14 +311,14 @@ class AIMessageChunk:
)
for tc in self._tool_calls
]
if self._invalid_tool_calls:
if self.invalid_tool_calls:
tool_call_chunks = self.tool_call_chunks
tool_call_chunks.extend(
[
create_tool_call_chunk(
name=tc["name"], args=tc["args"], id=tc["id"], index=None
)
for tc in self._invalid_tool_calls
for tc in self.invalid_tool_calls
]
)
self.tool_call_chunks = tool_call_chunks
@@ -317,7 +352,7 @@ class AIMessageChunk:
except Exception:
add_chunk_to_invalid_tool_calls(chunk)
self._tool_calls = tool_calls
self._invalid_tool_calls = invalid_tool_calls
self.invalid_tool_calls = invalid_tool_calls
@property
def text(self) -> Optional[str]:
@@ -361,6 +396,20 @@ class AIMessageChunk:
error_msg = "Can only add AIMessageChunk or sequence of AIMessageChunk."
raise NotImplementedError(error_msg)
def to_message(self) -> "AIMessage":
"""Convert this AIMessageChunk to an AIMessage."""
return AIMessage(
content=self.content,
id=self.id,
name=self.name,
lc_version=self.lc_version,
response_metadata=self.response_metadata,
usage_metadata=self.usage_metadata,
tool_calls=self.tool_calls,
invalid_tool_calls=self.invalid_tool_calls,
parsed=self.parsed,
)
def add_ai_message_chunks(
left: AIMessageChunk, *others: AIMessageChunk
@@ -371,7 +420,8 @@ def add_ai_message_chunks(
*(cast("list[str | dict[Any, Any]]", o.content) for o in others),
)
response_metadata = merge_dicts(
left.response_metadata, *(o.response_metadata for o in others)
cast("dict", left.response_metadata),
*(cast("dict", o.response_metadata) for o in others),
)
# Merge tool call chunks
@@ -398,6 +448,15 @@ def add_ai_message_chunks(
else:
usage_metadata = None
# Parsed
# 'parsed' always represents an aggregation not an incremental value, so the last
# non-null value is kept.
parsed = None
for m in reversed([left, *others]):
if m.parsed is not None:
parsed = m.parsed
break
chunk_id = None
candidates = [left.id] + [o.id for o in others]
# first pass: pick the first non-run-* id
@@ -415,8 +474,9 @@ def add_ai_message_chunks(
return left.__class__(
content=cast("list[types.ContentBlock]", content),
tool_call_chunks=tool_call_chunks,
response_metadata=response_metadata,
response_metadata=cast("ResponseMetadata", response_metadata),
usage_metadata=usage_metadata,
parsed=parsed,
id=chunk_id,
)
@@ -453,19 +513,25 @@ class HumanMessage:
"""
def __init__(
self, content: Union[str, list[types.ContentBlock]], id: Optional[str] = None
self,
content: Union[str, list[types.ContentBlock]],
*,
id: Optional[str] = None,
name: Optional[str] = None,
):
"""Initialize a human message.
Args:
content: Message content as string or list of content blocks.
id: Optional unique identifier for the message.
name: Optional human-readable name for the message.
"""
self.id = _ensure_id(id)
if isinstance(content, str):
self.content = [{"type": "text", "text": content}]
else:
self.content = content
self.name = name
def text(self) -> str:
"""Extract all text content from the message.
@@ -495,20 +561,47 @@ class SystemMessage:
content: list[types.ContentBlock]
type: Literal["system"] = "system"
name: Optional[str] = None
"""An optional name for the message.
This can be used to provide a human-readable name for the message.
Usage of this field is optional, and whether it's used or not is up to the
model implementation.
"""
custom_role: Optional[str] = None
"""If provided, a custom role for the system message.
Example: ``"developer"``.
Integration packages may use this field to assign the system message role if it
contains a recognized value.
"""
def __init__(
self, content: Union[str, list[types.ContentBlock]], *, id: Optional[str] = None
self,
content: Union[str, list[types.ContentBlock]],
*,
id: Optional[str] = None,
custom_role: Optional[str] = None,
name: Optional[str] = None,
):
"""Initialize a system message.
"""Initialize a human message.
Args:
content: System instructions as string or list of content blocks.
content: Message content as string or list of content blocks.
id: Optional unique identifier for the message.
custom_role: If provided, a custom role for the system message.
name: Optional human-readable name for the message.
"""
self.id = _ensure_id(id)
if isinstance(content, str):
self.content = [{"type": "text", "text": content}]
else:
self.content = content
self.custom_role = custom_role
self.name = name
def text(self) -> str:
"""Extract all text content from the system message."""
@@ -535,11 +628,51 @@ class ToolMessage:
id: str
tool_call_id: str
content: list[dict[str, Any]]
content: list[types.ContentBlock]
artifact: Optional[Any] = None # App-side payload not for the model
name: Optional[str] = None
"""An optional name for the message.
This can be used to provide a human-readable name for the message.
Usage of this field is optional, and whether it's used or not is up to the
model implementation.
"""
status: Literal["success", "error"] = "success"
type: Literal["tool"] = "tool"
def __init__(
self,
content: Union[str, list[types.ContentBlock]],
tool_call_id: str,
*,
id: Optional[str] = None,
name: Optional[str] = None,
artifact: Optional[Any] = None,
status: Literal["success", "error"] = "success",
):
"""Initialize a human message.
Args:
content: Message content as string or list of content blocks.
tool_call_id: ID of the tool call this message responds to.
id: Optional unique identifier for the message.
name: Optional human-readable name for the message.
artifact: Optional app-side payload not intended for the model.
status: Execution status ("success" or "error").
"""
self.id = _ensure_id(id)
self.tool_call_id = tool_call_id
if isinstance(content, str):
self.content = [{"type": "text", "text": content}]
else:
self.content = content
self.name = name
self.artifact = artifact
self.status = status
@property
def text(self) -> str:
"""Extract all text content from the tool message."""

View File

@@ -14,16 +14,18 @@ from typing_extensions import TypedDict, overload
from langchain_core.load.serializable import Serializable
from langchain_core.messages import (
AIMessage,
AnyMessage,
BaseMessage,
HumanMessage,
SystemMessage,
ToolMessage,
get_buffer_string,
)
from langchain_core.messages import content_blocks as types
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.messages.v1 import HumanMessage as HumanMessageV1
from langchain_core.messages.v1 import MessageV1
from langchain_core.messages.v1 import MessageV1, ResponseMetadata
from langchain_core.messages.v1 import SystemMessage as SystemMessageV1
from langchain_core.messages.v1 import ToolMessage as ToolMessageV1
@@ -40,7 +42,7 @@ def _convert_to_v1(message: BaseMessage) -> MessageV1:
if isinstance(block, str):
content.append({"type": "text", "text": block})
elif isinstance(block, dict):
content.append(block)
content.append(cast("types.ContentBlock", block))
else:
pass
@@ -52,7 +54,7 @@ def _convert_to_v1(message: BaseMessage) -> MessageV1:
return AIMessageV1(
content=content,
usage_metadata=message.usage_metadata,
response_metadata=message.response_metadata,
response_metadata=cast("ResponseMetadata", message.response_metadata),
tool_calls=message.tool_calls,
)
if isinstance(message, SystemMessage):
@@ -92,8 +94,18 @@ class PromptValue(Serializable, ABC):
def to_string(self) -> str:
"""Return prompt value as string."""
@overload
def to_messages(
self, output_version: Literal["v0"] = "v0"
) -> list[BaseMessage]: ...
@overload
def to_messages(self, output_version: Literal["v1"]) -> list[MessageV1]: ...
@abstractmethod
def to_messages(self) -> list[BaseMessage]:
def to_messages(
self, output_version: Literal["v0", "v1"] = "v0"
) -> Union[Sequence[BaseMessage], Sequence[MessageV1]]:
"""Return prompt as a list of Messages."""
@@ -117,10 +129,6 @@ class StringPromptValue(PromptValue):
"""Return prompt as string."""
return self.text
def to_messages(self) -> list[BaseMessage]:
"""Return prompt as messages."""
return [HumanMessage(content=self.text)]
@overload
def to_messages(
self, output_version: Literal["v0"] = "v0"
@@ -131,12 +139,8 @@ class StringPromptValue(PromptValue):
def to_messages(
self, output_version: Literal["v0", "v1"] = "v0"
) -> Union[list[BaseMessage], list[MessageV1]]:
"""Return prompt as a list of messages.
Args:
output_version: The output version, either "v0" (default) or "v1".
"""
) -> Union[Sequence[BaseMessage], Sequence[MessageV1]]:
"""Return prompt as messages."""
if output_version == "v1":
return [HumanMessageV1(content=self.text)]
return [HumanMessage(content=self.text)]
@@ -165,7 +169,7 @@ class ChatPromptValue(PromptValue):
def to_messages(
self, output_version: Literal["v0", "v1"] = "v0"
) -> Union[list[BaseMessage], list[MessageV1]]:
) -> Union[Sequence[BaseMessage], Sequence[MessageV1]]:
"""Return prompt as a list of messages.
Args:
@@ -207,8 +211,26 @@ class ImagePromptValue(PromptValue):
"""Return prompt (image URL) as string."""
return self.image_url["url"]
def to_messages(self) -> list[BaseMessage]:
@overload
def to_messages(
self, output_version: Literal["v0"] = "v0"
) -> list[BaseMessage]: ...
@overload
def to_messages(self, output_version: Literal["v1"]) -> list[MessageV1]: ...
def to_messages(
self, output_version: Literal["v0", "v1"] = "v0"
) -> Union[Sequence[BaseMessage], Sequence[MessageV1]]:
"""Return prompt (image URL) as messages."""
if output_version == "v1":
block: types.ImageContentBlock = {
"type": "image",
"url": self.image_url["url"],
}
if "detail" in self.image_url:
block["detail"] = self.image_url["detail"]
return [HumanMessageV1(content=[block])]
return [HumanMessage(content=[cast("dict", self.image_url)])]

View File

@@ -67,6 +67,7 @@ langchain-text-splitters = { path = "../text-splitters" }
strict = "True"
strict_bytes = "True"
enable_error_code = "deprecated"
disable_error_code = ["typeddict-unknown-key"]
# TODO: activate for 'strict' checking
disallow_any_generics = "False"