feat(openai): v1 message format support (#32296)

This commit is contained in:
ccurme 2025-07-28 19:42:26 -03:00 committed by GitHub
parent 7166adce1f
commit c15e55b33c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 4682 additions and 299 deletions

View File

@ -2,8 +2,7 @@
from __future__ import annotations
import warnings
from abc import ABC, abstractmethod
from abc import ABC
from collections.abc import Mapping, Sequence
from functools import cache
from typing import (
@ -26,7 +25,6 @@ from langchain_core.messages import (
AnyMessage,
BaseMessage,
MessageLikeRepresentation,
get_buffer_string,
)
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.prompt_values import PromptValue
@ -166,7 +164,6 @@ class BaseLanguageModel(
list[AnyMessage],
]
@abstractmethod
def generate_prompt(
self,
prompts: list[PromptValue],
@ -201,7 +198,6 @@ class BaseLanguageModel(
prompt and additional model provider-specific output.
"""
@abstractmethod
async def agenerate_prompt(
self,
prompts: list[PromptValue],
@ -245,7 +241,6 @@ class BaseLanguageModel(
raise NotImplementedError
@deprecated("0.1.7", alternative="invoke", removal="1.0")
@abstractmethod
def predict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str:
@ -266,7 +261,6 @@ class BaseLanguageModel(
"""
@deprecated("0.1.7", alternative="invoke", removal="1.0")
@abstractmethod
def predict_messages(
self,
messages: list[BaseMessage],
@ -291,7 +285,6 @@ class BaseLanguageModel(
"""
@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
@abstractmethod
async def apredict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str:
@ -312,7 +305,6 @@ class BaseLanguageModel(
"""
@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
@abstractmethod
async def apredict_messages(
self,
messages: list[BaseMessage],
@ -368,33 +360,6 @@ class BaseLanguageModel(
"""
return len(self.get_token_ids(text))
def get_num_tokens_from_messages(
self,
messages: list[BaseMessage],
tools: Optional[Sequence] = None,
) -> int:
"""Get the number of tokens in the messages.
Useful for checking if an input fits in a model's context window.
**Note**: the base implementation of get_num_tokens_from_messages ignores
tool schemas.
Args:
messages: The message inputs to tokenize.
tools: If provided, sequence of dict, BaseModel, function, or BaseTools
to be converted to tool schemas.
Returns:
The sum of the number of tokens across the messages.
"""
if tools is not None:
warnings.warn(
"Counting tokens in tool schemas is not yet supported. Ignoring tools.",
stacklevel=2,
)
return sum(self.get_num_tokens(get_buffer_string([m])) for m in messages)
@classmethod
def _all_required_field_names(cls) -> set:
"""DEPRECATED: Kept for backwards compatibility.

View File

@ -55,12 +55,11 @@ from langchain_core.messages import (
HumanMessage,
convert_to_messages,
convert_to_openai_image_block,
get_buffer_string,
is_data_content_block,
message_chunk_to_message,
)
from langchain_core.messages import content_blocks as types
from langchain_core.messages.ai import _LC_ID_PREFIX
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.outputs import (
ChatGeneration,
ChatGenerationChunk,
@ -222,23 +221,6 @@ def _format_ls_structured_output(ls_structured_output_format: Optional[dict]) ->
return ls_structured_output_format_dict
def _convert_to_v1(message: AIMessage) -> AIMessageV1:
"""Best-effort conversion of a V0 AIMessage to V1."""
if isinstance(message.content, str):
content: list[types.ContentBlock] = []
if message.content:
content = [{"type": "text", "text": message.content}]
for tool_call in message.tool_calls:
content.append(tool_call)
return AIMessageV1(
content=content,
usage_metadata=message.usage_metadata,
response_metadata=message.response_metadata,
)
class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
"""Base class for chat models.
@ -1370,6 +1352,33 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
starter_dict["_type"] = self._llm_type
return starter_dict
def get_num_tokens_from_messages(
self,
messages: list[BaseMessage],
tools: Optional[Sequence] = None,
) -> int:
"""Get the number of tokens in the messages.
Useful for checking if an input fits in a model's context window.
**Note**: the base implementation of get_num_tokens_from_messages ignores
tool schemas.
Args:
messages: The message inputs to tokenize.
tools: If provided, sequence of dict, BaseModel, function, or BaseTools
to be converted to tool schemas.
Returns:
The sum of the number of tokens across the messages.
"""
if tools is not None:
warnings.warn(
"Counting tokens in tool schemas is not yet supported. Ignoring tools.",
stacklevel=2,
)
return sum(self.get_num_tokens(get_buffer_string([m])) for m in messages)
def bind_tools(
self,
tools: Sequence[

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import copy
import typing
import warnings
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Iterator, Sequence
from operator import itemgetter
@ -38,11 +39,14 @@ from langchain_core.language_models.base import (
)
from langchain_core.messages import (
AIMessage,
BaseMessage,
convert_to_openai_image_block,
get_buffer_string,
is_data_content_block,
)
from langchain_core.messages.utils import convert_to_messages_v1
from langchain_core.messages.utils import (
_convert_from_v1_message,
convert_to_messages_v1,
)
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
from langchain_core.messages.v1 import HumanMessage as HumanMessageV1
@ -735,7 +739,7 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
*,
tool_choice: Optional[Union[str]] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
) -> Runnable[LanguageModelInput, AIMessageV1]:
"""Bind tools to the model.
Args:
@ -899,6 +903,34 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
return RunnableMap(raw=llm) | parser_with_fallback
return llm | output_parser
def get_num_tokens_from_messages(
self,
messages: list[MessageV1],
tools: Optional[Sequence] = None,
) -> int:
"""Get the number of tokens in the messages.
Useful for checking if an input fits in a model's context window.
**Note**: the base implementation of get_num_tokens_from_messages ignores
tool schemas.
Args:
messages: The message inputs to tokenize.
tools: If provided, sequence of dict, BaseModel, function, or BaseTools
to be converted to tool schemas.
Returns:
The sum of the number of tokens across the messages.
"""
messages_v0 = [_convert_from_v1_message(message) for message in messages]
if tools is not None:
warnings.warn(
"Counting tokens in tool schemas is not yet supported. Ignoring tools.",
stacklevel=2,
)
return sum(self.get_num_tokens(get_buffer_string([m])) for m in messages_v0)
def _gen_info_and_msg_metadata(
generation: Union[ChatGeneration, ChatGenerationChunk],

View File

@ -706,6 +706,7 @@ ToolContentBlock = Union[
ContentBlock = Union[
TextContentBlock,
ToolCall,
InvalidToolCall,
ReasoningContentBlock,
NonStandardContentBlock,
DataContentBlock,

View File

@ -384,38 +384,37 @@ def _convert_from_v1_message(message: MessageV1) -> BaseMessage:
Returns:
BaseMessage: Converted message instance.
"""
# type ignores here are because AIMessageV1.content is a list of dicts.
# AIMessageV0.content expects str or list[str | dict].
content = cast("Union[str, list[str | dict]]", message.content)
if isinstance(message, AIMessageV1):
return AIMessage(
content=message.content, # type: ignore[arg-type]
content=content,
id=message.id,
name=message.name,
tool_calls=message.tool_calls,
response_metadata=message.response_metadata,
response_metadata=cast("dict", message.response_metadata),
)
if isinstance(message, AIMessageChunkV1):
return AIMessageChunk(
content=message.content, # type: ignore[arg-type]
content=content,
id=message.id,
name=message.name,
tool_call_chunks=message.tool_call_chunks,
response_metadata=message.response_metadata,
response_metadata=cast("dict", message.response_metadata),
)
if isinstance(message, HumanMessageV1):
return HumanMessage(
content=message.content, # type: ignore[arg-type]
content=content,
id=message.id,
name=message.name,
)
if isinstance(message, SystemMessageV1):
return SystemMessage(
content=message.content, # type: ignore[arg-type]
content=content,
id=message.id,
)
if isinstance(message, ToolMessageV1):
return ToolMessage(
content=message.content, # type: ignore[arg-type]
content=content,
id=message.id,
)
message = f"Unsupported message type: {type(message)}"
@ -501,7 +500,10 @@ def _convert_to_message_v1(message: MessageLikeRepresentation) -> MessageV1:
ValueError: if the message dict does not contain the required keys.
"""
if isinstance(message, MessageV1Types):
message_ = message
if isinstance(message, AIMessageChunkV1):
message_ = message.to_message()
else:
message_ = message
elif isinstance(message, str):
message_ = _create_message_from_message_type_v1("human", message)
elif isinstance(message, Sequence) and len(message) == 2:

View File

@ -5,6 +5,8 @@ import uuid
from dataclasses import dataclass, field
from typing import Any, Literal, Optional, TypedDict, Union, cast, get_args
from pydantic import BaseModel
import langchain_core.messages.content_blocks as types
from langchain_core.messages.ai import _LC_ID_PREFIX, UsageMetadata, add_usage
from langchain_core.messages.base import merge_content
@ -32,20 +34,20 @@ def _ensure_id(id_val: Optional[str]) -> str:
return id_val or str(uuid.uuid4())
class Provider(TypedDict):
"""Information about the provider that generated the message.
class ResponseMetadata(TypedDict, total=False):
"""Metadata about the response from the AI provider.
Contains metadata about the AI provider and model used to generate content.
Contains additional information returned by the provider, such as
response headers, service tiers, log probabilities, system fingerprints, etc.
Attributes:
name: Name and version of the provider that created the content block.
model_name: Name of the model that generated the content block.
Extra keys are permitted from what is typed here.
"""
name: str
"""Name and version of the provider that created the content block."""
model_provider: str
"""Name and version of the provider that created the message (e.g., openai)."""
model_name: str
"""Name of the model that generated the content block."""
"""Name of the model that generated the message."""
@dataclass
@ -91,21 +93,29 @@ class AIMessage:
usage_metadata: Optional[UsageMetadata] = None
"""If provided, usage metadata for a message, such as token counts."""
response_metadata: dict = field(default_factory=dict)
response_metadata: ResponseMetadata = field(
default_factory=lambda: ResponseMetadata()
)
"""Metadata about the response.
This field should include non-standard data returned by the provider, such as
response headers, service tiers, or log probabilities.
"""
parsed: Optional[Union[dict[str, Any], BaseModel]] = None
"""Auto-parsed message contents, if applicable."""
def __init__(
self,
content: Union[str, list[types.ContentBlock]],
id: Optional[str] = None,
name: Optional[str] = None,
lc_version: str = "v1",
response_metadata: Optional[dict] = None,
response_metadata: Optional[ResponseMetadata] = None,
usage_metadata: Optional[UsageMetadata] = None,
tool_calls: Optional[list[types.ToolCall]] = None,
invalid_tool_calls: Optional[list[types.InvalidToolCall]] = None,
parsed: Optional[Union[dict[str, Any], BaseModel]] = None,
):
"""Initialize an AI message.
@ -116,6 +126,11 @@ class AIMessage:
lc_version: Encoding version for the message.
response_metadata: Optional metadata about the response.
usage_metadata: Optional metadata about token usage.
tool_calls: Optional list of tool calls made by the AI. Tool calls should
generally be included in message content. If passed on init, they will
be added to the content list.
invalid_tool_calls: Optional list of tool calls that failed validation.
parsed: Optional auto-parsed message contents, if applicable.
"""
if isinstance(content, str):
self.content = [{"type": "text", "text": content}]
@ -126,13 +141,27 @@ class AIMessage:
self.name = name
self.lc_version = lc_version
self.usage_metadata = usage_metadata
self.parsed = parsed
if response_metadata is None:
self.response_metadata = {}
else:
self.response_metadata = response_metadata
self._tool_calls: list[types.ToolCall] = []
self._invalid_tool_calls: list[types.InvalidToolCall] = []
# Add tool calls to content if provided on init
if tool_calls:
content_tool_calls = {
block["id"]
for block in self.content
if block["type"] == "tool_call" and "id" in block
}
for tool_call in tool_calls:
if "id" in tool_call and tool_call["id"] in content_tool_calls:
continue
self.content.append(tool_call)
self._tool_calls = [
block for block in self.content if block["type"] == "tool_call"
]
self.invalid_tool_calls = invalid_tool_calls or []
@property
def text(self) -> Optional[str]:
@ -150,7 +179,7 @@ class AIMessage:
tool_calls = [block for block in self.content if block["type"] == "tool_call"]
if tool_calls:
self._tool_calls = tool_calls
return self._tool_calls
return [block for block in self.content if block["type"] == "tool_call"]
@tool_calls.setter
def tool_calls(self, value: list[types.ToolCall]) -> None:
@ -202,13 +231,16 @@ class AIMessageChunk:
These data represent incremental usage statistics, as opposed to a running total.
"""
response_metadata: dict = field(init=False)
response_metadata: ResponseMetadata = field(init=False)
"""Metadata about the response chunk.
This field should include non-standard data returned by the provider, such as
response headers, service tiers, or log probabilities.
"""
parsed: Optional[Union[dict[str, Any], BaseModel]] = None
"""Auto-parsed message contents, if applicable."""
tool_call_chunks: list[types.ToolCallChunk] = field(init=False)
def __init__(
@ -217,9 +249,10 @@ class AIMessageChunk:
id: Optional[str] = None,
name: Optional[str] = None,
lc_version: str = "v1",
response_metadata: Optional[dict] = None,
response_metadata: Optional[ResponseMetadata] = None,
usage_metadata: Optional[UsageMetadata] = None,
tool_call_chunks: Optional[list[types.ToolCallChunk]] = None,
parsed: Optional[Union[dict[str, Any], BaseModel]] = None,
):
"""Initialize an AI message.
@ -231,6 +264,7 @@ class AIMessageChunk:
response_metadata: Optional metadata about the response.
usage_metadata: Optional metadata about token usage.
tool_call_chunks: Optional list of partial tool call data.
parsed: Optional auto-parsed message contents, if applicable.
"""
if isinstance(content, str):
self.content = [{"type": "text", "text": content, "index": 0}]
@ -241,6 +275,7 @@ class AIMessageChunk:
self.name = name
self.lc_version = lc_version
self.usage_metadata = usage_metadata
self.parsed = parsed
if response_metadata is None:
self.response_metadata = {}
else:
@ -251,7 +286,7 @@ class AIMessageChunk:
self.tool_call_chunks = tool_call_chunks
self._tool_calls: list[types.ToolCall] = []
self._invalid_tool_calls: list[types.InvalidToolCall] = []
self.invalid_tool_calls: list[types.InvalidToolCall] = []
self._init_tool_calls()
def _init_tool_calls(self) -> None:
@ -264,7 +299,7 @@ class AIMessageChunk:
ValueError: If the tool call chunks are malformed.
"""
self._tool_calls = []
self._invalid_tool_calls = []
self.invalid_tool_calls = []
if not self.tool_call_chunks:
if self._tool_calls:
self.tool_call_chunks = [
@ -276,14 +311,14 @@ class AIMessageChunk:
)
for tc in self._tool_calls
]
if self._invalid_tool_calls:
if self.invalid_tool_calls:
tool_call_chunks = self.tool_call_chunks
tool_call_chunks.extend(
[
create_tool_call_chunk(
name=tc["name"], args=tc["args"], id=tc["id"], index=None
)
for tc in self._invalid_tool_calls
for tc in self.invalid_tool_calls
]
)
self.tool_call_chunks = tool_call_chunks
@ -294,9 +329,9 @@ class AIMessageChunk:
def add_chunk_to_invalid_tool_calls(chunk: ToolCallChunk) -> None:
invalid_tool_calls.append(
create_invalid_tool_call(
name=chunk["name"],
args=chunk["args"],
id=chunk["id"],
name=chunk.get("name", ""),
args=chunk.get("args", ""),
id=chunk.get("id", ""),
error=None,
)
)
@ -307,9 +342,9 @@ class AIMessageChunk:
if isinstance(args_, dict):
tool_calls.append(
create_tool_call(
name=chunk["name"] or "",
name=chunk.get("name", ""),
args=args_,
id=chunk["id"],
id=chunk.get("id", ""),
)
)
else:
@ -317,7 +352,7 @@ class AIMessageChunk:
except Exception:
add_chunk_to_invalid_tool_calls(chunk)
self._tool_calls = tool_calls
self._invalid_tool_calls = invalid_tool_calls
self.invalid_tool_calls = invalid_tool_calls
@property
def text(self) -> Optional[str]:
@ -361,6 +396,20 @@ class AIMessageChunk:
error_msg = "Can only add AIMessageChunk or sequence of AIMessageChunk."
raise NotImplementedError(error_msg)
def to_message(self) -> "AIMessage":
"""Convert this AIMessageChunk to an AIMessage."""
return AIMessage(
content=self.content,
id=self.id,
name=self.name,
lc_version=self.lc_version,
response_metadata=self.response_metadata,
usage_metadata=self.usage_metadata,
tool_calls=self.tool_calls,
invalid_tool_calls=self.invalid_tool_calls,
parsed=self.parsed,
)
def add_ai_message_chunks(
left: AIMessageChunk, *others: AIMessageChunk
@ -373,7 +422,8 @@ def add_ai_message_chunks(
*(cast("list[str | dict[Any, Any]]", o.content) for o in others),
)
response_metadata = merge_dicts(
left.response_metadata, *(o.response_metadata for o in others)
cast("dict", left.response_metadata),
*(cast("dict", o.response_metadata) for o in others),
)
# Merge tool call chunks
@ -400,6 +450,15 @@ def add_ai_message_chunks(
else:
usage_metadata = None
# Parsed
# 'parsed' always represents an aggregation not an incremental value, so the last
# non-null value is kept.
parsed = None
for m in reversed([left, *others]):
if m.parsed is not None:
parsed = m.parsed
break
chunk_id = None
candidates = [left.id] + [o.id for o in others]
# first pass: pick the first non-run-* id
@ -417,8 +476,9 @@ def add_ai_message_chunks(
return left.__class__(
content=cast("list[types.ContentBlock]", content),
tool_call_chunks=tool_call_chunks,
response_metadata=response_metadata,
response_metadata=cast("ResponseMetadata", response_metadata),
usage_metadata=usage_metadata,
parsed=parsed,
id=chunk_id,
)
@ -455,19 +515,25 @@ class HumanMessage:
"""
def __init__(
self, content: Union[str, list[types.ContentBlock]], id: Optional[str] = None
self,
content: Union[str, list[types.ContentBlock]],
*,
id: Optional[str] = None,
name: Optional[str] = None,
):
"""Initialize a human message.
Args:
content: Message content as string or list of content blocks.
id: Optional unique identifier for the message.
name: Optional human-readable name for the message.
"""
self.id = _ensure_id(id)
if isinstance(content, str):
self.content = [{"type": "text", "text": content}]
else:
self.content = content
self.name = name
def text(self) -> str:
"""Extract all text content from the message.
@ -497,20 +563,47 @@ class SystemMessage:
content: list[types.ContentBlock]
type: Literal["system"] = "system"
name: Optional[str] = None
"""An optional name for the message.
This can be used to provide a human-readable name for the message.
Usage of this field is optional, and whether it's used or not is up to the
model implementation.
"""
custom_role: Optional[str] = None
"""If provided, a custom role for the system message.
Example: ``"developer"``.
Integration packages may use this field to assign the system message role if it
contains a recognized value.
"""
def __init__(
self, content: Union[str, list[types.ContentBlock]], *, id: Optional[str] = None
self,
content: Union[str, list[types.ContentBlock]],
*,
id: Optional[str] = None,
custom_role: Optional[str] = None,
name: Optional[str] = None,
):
"""Initialize a system message.
"""Initialize a human message.
Args:
content: System instructions as string or list of content blocks.
content: Message content as string or list of content blocks.
id: Optional unique identifier for the message.
custom_role: If provided, a custom role for the system message.
name: Optional human-readable name for the message.
"""
self.id = _ensure_id(id)
if isinstance(content, str):
self.content = [{"type": "text", "text": content}]
else:
self.content = content
self.custom_role = custom_role
self.name = name
def text(self) -> str:
"""Extract all text content from the system message."""
@ -537,11 +630,51 @@ class ToolMessage:
id: str
tool_call_id: str
content: list[dict[str, Any]]
content: list[types.ContentBlock]
artifact: Optional[Any] = None # App-side payload not for the model
name: Optional[str] = None
"""An optional name for the message.
This can be used to provide a human-readable name for the message.
Usage of this field is optional, and whether it's used or not is up to the
model implementation.
"""
status: Literal["success", "error"] = "success"
type: Literal["tool"] = "tool"
def __init__(
self,
content: Union[str, list[types.ContentBlock]],
tool_call_id: str,
*,
id: Optional[str] = None,
name: Optional[str] = None,
artifact: Optional[Any] = None,
status: Literal["success", "error"] = "success",
):
"""Initialize a human message.
Args:
content: Message content as string or list of content blocks.
tool_call_id: ID of the tool call this message responds to.
id: Optional unique identifier for the message.
name: Optional human-readable name for the message.
artifact: Optional app-side payload not intended for the model.
status: Execution status ("success" or "error").
"""
self.id = _ensure_id(id)
self.tool_call_id = tool_call_id
if isinstance(content, str):
self.content = [{"type": "text", "text": content}]
else:
self.content = content
self.name = name
self.artifact = artifact
self.status = status
@property
def text(self) -> str:
"""Extract all text content from the tool message."""

View File

@ -9,7 +9,7 @@ from typing import Annotated, Any, Optional
from pydantic import SkipValidation, ValidationError
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import AIMessage, InvalidToolCall
from langchain_core.messages import AIMessage, InvalidToolCall, ToolCall
from langchain_core.messages.tool import invalid_tool_call
from langchain_core.messages.tool import tool_call as create_tool_call
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
@ -26,7 +26,7 @@ def parse_tool_call(
partial: bool = False,
strict: bool = False,
return_id: bool = True,
) -> Optional[dict[str, Any]]:
) -> Optional[ToolCall]:
"""Parse a single tool call.
Args:

View File

@ -8,17 +8,65 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Literal, cast
from typing import Literal, Union, cast
from typing_extensions import TypedDict
from typing_extensions import TypedDict, overload
from langchain_core.load.serializable import Serializable
from langchain_core.messages import (
AIMessage,
AnyMessage,
BaseMessage,
HumanMessage,
SystemMessage,
ToolMessage,
get_buffer_string,
)
from langchain_core.messages import content_blocks as types
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.messages.v1 import HumanMessage as HumanMessageV1
from langchain_core.messages.v1 import MessageV1, ResponseMetadata
from langchain_core.messages.v1 import SystemMessage as SystemMessageV1
from langchain_core.messages.v1 import ToolMessage as ToolMessageV1
def _convert_to_v1(message: BaseMessage) -> MessageV1:
"""Best-effort conversion of a V0 AIMessage to V1."""
if isinstance(message.content, str):
content: list[types.ContentBlock] = []
if message.content:
content = [{"type": "text", "text": message.content}]
else:
content = []
for block in message.content:
if isinstance(block, str):
content.append({"type": "text", "text": block})
elif isinstance(block, dict):
content.append(cast("types.ContentBlock", block))
else:
pass
if isinstance(message, HumanMessage):
return HumanMessageV1(content=content)
if isinstance(message, AIMessage):
for tool_call in message.tool_calls:
content.append(tool_call)
return AIMessageV1(
content=content,
usage_metadata=message.usage_metadata,
response_metadata=cast("ResponseMetadata", message.response_metadata),
tool_calls=message.tool_calls,
)
if isinstance(message, SystemMessage):
return SystemMessageV1(content=content)
if isinstance(message, ToolMessage):
return ToolMessageV1(
tool_call_id=message.tool_call_id,
content=content,
artifact=message.artifact,
)
error_message = f"Unsupported message type: {type(message)}"
raise TypeError(error_message)
class PromptValue(Serializable, ABC):
@ -46,8 +94,18 @@ class PromptValue(Serializable, ABC):
def to_string(self) -> str:
"""Return prompt value as string."""
@overload
def to_messages(
self, output_version: Literal["v0"] = "v0"
) -> list[BaseMessage]: ...
@overload
def to_messages(self, output_version: Literal["v1"]) -> list[MessageV1]: ...
@abstractmethod
def to_messages(self) -> list[BaseMessage]:
def to_messages(
self, output_version: Literal["v0", "v1"] = "v0"
) -> Union[Sequence[BaseMessage], Sequence[MessageV1]]:
"""Return prompt as a list of Messages."""
@ -71,8 +129,20 @@ class StringPromptValue(PromptValue):
"""Return prompt as string."""
return self.text
def to_messages(self) -> list[BaseMessage]:
@overload
def to_messages(
self, output_version: Literal["v0"] = "v0"
) -> list[BaseMessage]: ...
@overload
def to_messages(self, output_version: Literal["v1"]) -> list[MessageV1]: ...
def to_messages(
self, output_version: Literal["v0", "v1"] = "v0"
) -> Union[Sequence[BaseMessage], Sequence[MessageV1]]:
"""Return prompt as messages."""
if output_version == "v1":
return [HumanMessageV1(content=self.text)]
return [HumanMessage(content=self.text)]
@ -89,8 +159,24 @@ class ChatPromptValue(PromptValue):
"""Return prompt as string."""
return get_buffer_string(self.messages)
def to_messages(self) -> list[BaseMessage]:
"""Return prompt as a list of messages."""
@overload
def to_messages(
self, output_version: Literal["v0"] = "v0"
) -> list[BaseMessage]: ...
@overload
def to_messages(self, output_version: Literal["v1"]) -> list[MessageV1]: ...
def to_messages(
self, output_version: Literal["v0", "v1"] = "v0"
) -> Union[Sequence[BaseMessage], Sequence[MessageV1]]:
"""Return prompt as a list of messages.
Args:
output_version: The output version, either "v0" (default) or "v1".
"""
if output_version == "v1":
return [_convert_to_v1(m) for m in self.messages]
return list(self.messages)
@classmethod
@ -125,8 +211,26 @@ class ImagePromptValue(PromptValue):
"""Return prompt (image URL) as string."""
return self.image_url["url"]
def to_messages(self) -> list[BaseMessage]:
@overload
def to_messages(
self, output_version: Literal["v0"] = "v0"
) -> list[BaseMessage]: ...
@overload
def to_messages(self, output_version: Literal["v1"]) -> list[MessageV1]: ...
def to_messages(
self, output_version: Literal["v0", "v1"] = "v0"
) -> Union[Sequence[BaseMessage], Sequence[MessageV1]]:
"""Return prompt (image URL) as messages."""
if output_version == "v1":
block: types.ImageContentBlock = {
"type": "image",
"url": self.image_url["url"],
}
if "detail" in self.image_url:
block["detail"] = self.image_url["detail"]
return [HumanMessageV1(content=[block])]
return [HumanMessage(content=[cast("dict", self.image_url)])]

View File

@ -67,6 +67,7 @@ langchain-text-splitters = { path = "../text-splitters" }
strict = "True"
strict_bytes = "True"
enable_error_code = "deprecated"
disable_error_code = ["typeddict-unknown-key"]
# TODO: activate for 'strict' checking
disallow_any_generics = "False"

View File

@ -34,6 +34,7 @@ from langchain_core.messages.content_blocks import KNOWN_BLOCK_TYPES
from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
from langchain_core.messages.tool import tool_call as create_tool_call
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
from langchain_core.utils._merge import merge_lists
@ -197,7 +198,7 @@ def test_message_chunks() -> None:
assert (meaningful_id + default_id).id == "msg_def456"
def test_message_chunks_v2() -> None:
def test_message_chunks_v1() -> None:
left = AIMessageChunkV1("foo ", id="abc")
right = AIMessageChunkV1("bar")
expected = AIMessageChunkV1("foo bar", id="abc")
@ -230,7 +231,19 @@ def test_message_chunks_v2() -> None:
)
],
)
assert one + two + three == expected
result = one + two + three
assert result == expected
assert result.to_message() == AIMessageV1(
content=[
{
"name": "tool1",
"args": {"arg1": "value}"},
"id": "1",
"type": "tool_call",
}
]
)
assert (
AIMessageChunkV1(
@ -1326,6 +1339,7 @@ def test_known_block_types() -> None:
"text",
"text-plain",
"tool_call",
"invalid_tool_call",
"reasoning",
"non_standard",
"image",

View File

@ -1,10 +1,11 @@
from langchain_openai.chat_models import AzureChatOpenAI, ChatOpenAI
from langchain_openai.chat_models import AzureChatOpenAI, ChatOpenAI, ChatOpenAIV1
from langchain_openai.embeddings import AzureOpenAIEmbeddings, OpenAIEmbeddings
from langchain_openai.llms import AzureOpenAI, OpenAI
__all__ = [
"OpenAI",
"ChatOpenAI",
"ChatOpenAIV1",
"OpenAIEmbeddings",
"AzureOpenAI",
"AzureChatOpenAI",

View File

@ -1,4 +1,5 @@
from langchain_openai.chat_models.azure import AzureChatOpenAI
from langchain_openai.chat_models.base import ChatOpenAI
from langchain_openai.chat_models.base_v1 import ChatOpenAI as ChatOpenAIV1
__all__ = ["ChatOpenAI", "AzureChatOpenAI"]
__all__ = ["ChatOpenAI", "AzureChatOpenAI", "ChatOpenAIV1"]

View File

@ -66,11 +66,14 @@ For backwards compatibility, this module provides functions to convert between t
formats. The functions are used internally by ChatOpenAI.
""" # noqa: E501
import copy
import json
from collections.abc import Iterable, Iterator
from typing import Any, Literal, Union, cast
from typing import Any, Literal, Optional, Union, cast
from langchain_core.messages import AIMessage, AIMessageChunk, is_data_content_block
from langchain_core.messages import content_blocks as types
from langchain_core.messages.v1 import AIMessage as AIMessageV1
_FUNCTION_CALL_IDS_MAP_KEY = "__openai_function_call_ids__"
@ -289,25 +292,21 @@ def _convert_to_v1_from_chat_completions_chunk(chunk: AIMessageChunk) -> AIMessa
return cast(AIMessageChunk, result)
def _convert_from_v1_to_chat_completions(message: AIMessage) -> AIMessage:
def _convert_from_v1_to_chat_completions(message: AIMessageV1) -> AIMessageV1:
"""Convert a v1 message to the Chat Completions format."""
if isinstance(message.content, list):
new_content: list = []
for block in message.content:
if isinstance(block, dict):
block_type = block.get("type")
if block_type == "text":
# Strip annotations
new_content.append({"type": "text", "text": block["text"]})
elif block_type in ("reasoning", "tool_call"):
pass
else:
new_content.append(block)
else:
new_content.append(block)
return message.model_copy(update={"content": new_content})
new_content: list[types.ContentBlock] = []
for block in message.content:
if block["type"] == "text":
# Strip annotations
new_content.append({"type": "text", "text": block["text"]})
elif block["type"] in ("reasoning", "tool_call"):
pass
else:
new_content.append(block)
new_message = copy.copy(message)
new_message.content = new_content
return message
return new_message
# v1 / Responses
@ -319,17 +318,18 @@ def _convert_annotation_to_v1(annotation: dict[str, Any]) -> dict[str, Any]:
for field in ("end_index", "start_index", "title"):
if field in annotation:
url_citation[field] = annotation[field]
url_citation["type"] = "url_citation"
url_citation["type"] = "citation"
url_citation["url"] = annotation["url"]
return url_citation
elif annotation_type == "file_citation":
document_citation = {"type": "document_citation"}
document_citation = {"type": "citation"}
if "filename" in annotation:
document_citation["title"] = annotation["filename"]
for field in ("file_id", "index"): # OpenAI-specific
if field in annotation:
document_citation[field] = annotation[field]
if "file_id" in annotation:
document_citation["file_id"] = annotation["file_id"]
if "index" in annotation:
document_citation["file_index"] = annotation["index"]
return document_citation
# TODO: standardise container_file_citation?
@ -367,13 +367,15 @@ def _explode_reasoning(block: dict[str, Any]) -> Iterable[dict[str, Any]]:
yield new_block
def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage:
def _convert_to_v1_from_responses(
content: list[dict[str, Any]],
tool_calls: Optional[list[types.ToolCall]] = None,
invalid_tool_calls: Optional[list[types.InvalidToolCall]] = None,
) -> list[types.ContentBlock]:
"""Mutate a Responses message to v1 format."""
if not isinstance(message.content, list):
return message
def _iter_blocks() -> Iterable[dict[str, Any]]:
for block in message.content:
for block in content:
if not isinstance(block, dict):
continue
block_type = block.get("type")
@ -409,13 +411,24 @@ def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage:
yield new_block
elif block_type == "function_call":
new_block = {"type": "tool_call", "id": block.get("call_id", "")}
if "id" in block:
new_block["item_id"] = block["id"]
for extra_key in ("arguments", "name", "index"):
if extra_key in block:
new_block[extra_key] = block[extra_key]
yield new_block
new_block = None
call_id = block.get("call_id", "")
if call_id:
for tool_call in tool_calls or []:
if tool_call.get("id") == call_id:
new_block = tool_call.copy()
break
else:
for invalid_tool_call in invalid_tool_calls or []:
if invalid_tool_call.get("id") == call_id:
new_block = invalid_tool_call.copy()
break
if new_block:
if "id" in block:
new_block["item_id"] = block["id"]
if "index" in block:
new_block["index"] = block["index"]
yield new_block
elif block_type == "web_search_call":
web_search_call = {"type": "web_search_call", "id": block["id"]}
@ -485,28 +498,26 @@ def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage:
new_block["index"] = new_block["value"].pop("index")
yield new_block
# Replace the list with the fully converted one
message.content = list(_iter_blocks())
return message
return list(_iter_blocks())
def _convert_annotation_from_v1(annotation: dict[str, Any]) -> dict[str, Any]:
annotation_type = annotation.get("type")
def _convert_annotation_from_v1(annotation: types.Annotation) -> dict[str, Any]:
if annotation["type"] == "citation":
if "url" in annotation:
return {**annotation, "type": "url_citation"}
if annotation_type == "document_citation":
new_ann: dict[str, Any] = {"type": "file_citation"}
if "title" in annotation:
new_ann["filename"] = annotation["title"]
for fld in ("file_id", "index"):
if fld in annotation:
new_ann[fld] = annotation[fld]
if "file_id" in annotation:
new_ann["file_id"] = annotation["file_id"]
if "file_index" in annotation:
new_ann["index"] = annotation["file_index"]
return new_ann
elif annotation_type == "non_standard_annotation":
elif annotation["type"] == "non_standard_annotation":
return annotation["value"]
else:
@ -528,7 +539,10 @@ def _implode_reasoning_blocks(blocks: list[dict[str, Any]]) -> Iterable[dict[str
elif "reasoning" not in block and "summary" not in block:
# {"type": "reasoning", "id": "rs_..."}
oai_format = {**block, "summary": []}
# Update key order
oai_format["type"] = oai_format.pop("type", "reasoning")
if "encrypted_content" in oai_format:
oai_format["encrypted_content"] = oai_format.pop("encrypted_content")
yield oai_format
i += 1
continue
@ -594,13 +608,11 @@ def _consolidate_calls(
# If this really is the matching “result” collapse
if nxt.get("type") == result_name and nxt.get("id") == current.get("id"):
if call_name == "web_search_call":
collapsed = {
"id": current["id"],
"status": current["status"],
"type": "web_search_call",
}
collapsed = {"id": current["id"]}
if "action" in current:
collapsed["action"] = current["action"]
collapsed["status"] = current["status"]
collapsed["type"] = "web_search_call"
if call_name == "code_interpreter_call":
collapsed = {"id": current["id"]}
@ -621,51 +633,50 @@ def _consolidate_calls(
yield nxt
def _convert_from_v1_to_responses(message: AIMessage) -> AIMessage:
if not isinstance(message.content, list):
return message
def _convert_from_v1_to_responses(
content: list[types.ContentBlock], tool_calls: list[types.ToolCall]
) -> list[dict[str, Any]]:
new_content: list = []
for block in message.content:
if isinstance(block, dict):
block_type = block.get("type")
if block_type == "text" and "annotations" in block:
# Need a copy because were changing the annotations list
new_block = dict(block)
new_block["annotations"] = [
_convert_annotation_from_v1(a) for a in block["annotations"]
for block in content:
if block["type"] == "text" and "annotations" in block:
# Need a copy because were changing the annotations list
new_block = dict(block)
new_block["annotations"] = [
_convert_annotation_from_v1(a) for a in block["annotations"]
]
new_content.append(new_block)
elif block["type"] == "tool_call":
new_block = {"type": "function_call", "call_id": block["id"]}
if "item_id" in block:
new_block["id"] = block["item_id"] # type: ignore[typeddict-item]
if "name" in block and "arguments" in block:
new_block["name"] = block["name"]
new_block["arguments"] = block["arguments"] # type: ignore[typeddict-item]
else:
matching_tool_calls = [
call for call in tool_calls if call["id"] == block["id"]
]
new_content.append(new_block)
elif block_type == "tool_call":
new_block = {"type": "function_call", "call_id": block["id"]}
if "item_id" in block:
new_block["id"] = block["item_id"]
if "name" in block and "arguments" in block:
new_block["name"] = block["name"]
new_block["arguments"] = block["arguments"]
else:
tool_call = next(
call for call in message.tool_calls if call["id"] == block["id"]
)
if matching_tool_calls:
tool_call = matching_tool_calls[0]
if "name" not in block:
new_block["name"] = tool_call["name"]
if "arguments" not in block:
new_block["arguments"] = json.dumps(tool_call["args"])
new_content.append(new_block)
elif (
is_data_content_block(block)
and block["type"] == "image"
and "base64" in block
):
new_block = {"type": "image_generation_call", "result": block["base64"]}
for extra_key in ("id", "status"):
if extra_key in block:
new_block[extra_key] = block[extra_key]
new_content.append(new_block)
elif block_type == "non_standard" and "value" in block:
new_content.append(block["value"])
else:
new_content.append(block)
new_content.append(new_block)
elif (
is_data_content_block(cast(dict, block))
and block["type"] == "image"
and "base64" in block
and isinstance(block.get("id"), str)
and block["id"].startswith("ig_")
):
new_block = {"type": "image_generation_call", "result": block["base64"]}
for extra_key in ("id", "status"):
if extra_key in block:
new_block[extra_key] = block[extra_key] # type: ignore[typeddict-item]
new_content.append(new_block)
elif block["type"] == "non_standard" and "value" in block:
new_content.append(block["value"])
else:
new_content.append(block)
@ -679,4 +690,4 @@ def _convert_from_v1_to_responses(message: AIMessage) -> AIMessage:
)
)
return message.model_copy(update={"content": new_content})
return new_content

File diff suppressed because it is too large Load Diff

View File

@ -56,6 +56,8 @@ langchain-tests = { path = "../../standard-tests", editable = true }
[tool.mypy]
disallow_untyped_defs = "True"
disable_error_code = ["typeddict-unknown-key"]
[[tool.mypy.overrides]]
module = "transformers"
ignore_missing_imports = true

View File

@ -14,16 +14,24 @@ from langchain_core.messages import (
HumanMessage,
MessageLikeRepresentation,
)
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
from langchain_core.messages.v1 import HumanMessage as HumanMessageV1
from pydantic import BaseModel
from typing_extensions import TypedDict
from langchain_openai import ChatOpenAI
from langchain_openai import ChatOpenAI, ChatOpenAIV1
MODEL_NAME = "gpt-4o-mini"
def _check_response(response: Optional[BaseMessage]) -> None:
assert isinstance(response, AIMessage)
def _check_response(response: Optional[BaseMessage], output_version) -> None:
if output_version == "v1":
assert isinstance(response, AIMessageV1) or isinstance(
response, AIMessageChunkV1
)
else:
assert isinstance(response, AIMessage)
assert isinstance(response.content, list)
for block in response.content:
assert isinstance(block, dict)
@ -41,7 +49,10 @@ def _check_response(response: Optional[BaseMessage]) -> None:
for key in ["end_index", "start_index", "title", "type", "url"]
)
text_content = response.text()
if output_version == "v1":
text_content = response.text
else:
text_content = response.text()
assert isinstance(text_content, str)
assert text_content
assert response.usage_metadata
@ -56,22 +67,34 @@ def _check_response(response: Optional[BaseMessage]) -> None:
@pytest.mark.vcr
@pytest.mark.parametrize("output_version", ["responses/v1", "v1"])
def test_web_search(output_version: Literal["responses/v1", "v1"]) -> None:
llm = ChatOpenAI(model=MODEL_NAME, output_version=output_version)
if output_version == "v1":
llm = ChatOpenAIV1(model=MODEL_NAME)
else:
llm = ChatOpenAI(model=MODEL_NAME, output_version=output_version)
first_response = llm.invoke(
"What was a positive news story from today?",
tools=[{"type": "web_search_preview"}],
)
_check_response(first_response)
_check_response(first_response, output_version)
# Test streaming
full: Optional[BaseMessageChunk] = None
for chunk in llm.stream(
"What was a positive news story from today?",
tools=[{"type": "web_search_preview"}],
):
assert isinstance(chunk, AIMessageChunk)
full = chunk if full is None else full + chunk
_check_response(full)
if isinstance(llm, ChatOpenAIV1):
full: Optional[AIMessageChunkV1] = None
for chunk in llm.stream(
"What was a positive news story from today?",
tools=[{"type": "web_search_preview"}],
):
assert isinstance(chunk, AIMessageChunkV1)
full = chunk if full is None else full + chunk
else:
full: Optional[BaseMessageChunk] = None
for chunk in llm.stream(
"What was a positive news story from today?",
tools=[{"type": "web_search_preview"}],
):
assert isinstance(chunk, AIMessageChunk)
full = chunk if full is None else full + chunk
_check_response(full, output_version)
# Use OpenAI's stateful API
response = llm.invoke(
@ -79,38 +102,26 @@ def test_web_search(output_version: Literal["responses/v1", "v1"]) -> None:
tools=[{"type": "web_search_preview"}],
previous_response_id=first_response.response_metadata["id"],
)
_check_response(response)
_check_response(response, output_version)
# Manually pass in chat history
response = llm.invoke(
[
{
"role": "user",
"content": [
{
"type": "text",
"text": "What was a positive news story from today?",
}
],
},
{"role": "user", "content": "What was a positive news story from today?"},
first_response,
{
"role": "user",
"content": [{"type": "text", "text": "what about a negative one"}],
},
{"role": "user", "content": "what about a negative one"},
],
tools=[{"type": "web_search_preview"}],
)
_check_response(response)
_check_response(response, output_version)
# Bind tool
response = llm.bind_tools([{"type": "web_search_preview"}]).invoke(
"What was a positive news story from today?"
)
_check_response(response)
_check_response(response, output_version)
for msg in [first_response, full, response]:
assert isinstance(msg, AIMessage)
block_types = [block["type"] for block in msg.content] # type: ignore[index]
if output_version == "responses/v1":
assert block_types == ["web_search_call", "text"]
@ -125,7 +136,7 @@ async def test_web_search_async() -> None:
"What was a positive news story from today?",
tools=[{"type": "web_search_preview"}],
)
_check_response(response)
_check_response(response, "v0")
assert response.response_metadata["status"]
# Test streaming
@ -137,7 +148,7 @@ async def test_web_search_async() -> None:
assert isinstance(chunk, AIMessageChunk)
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk)
_check_response(full)
_check_response(full, "v0")
for msg in [response, full]:
assert msg.additional_kwargs["tool_outputs"]
@ -148,8 +159,8 @@ async def test_web_search_async() -> None:
@pytest.mark.default_cassette("test_function_calling.yaml.gz")
@pytest.mark.vcr
@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
def test_function_calling(output_version: Literal["v0", "responses/v1", "v1"]) -> None:
@pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
def test_function_calling(output_version: Literal["v0", "responses/v1"]) -> None:
def multiply(x: int, y: int) -> int:
"""return x * y"""
return x * y
@ -170,7 +181,33 @@ def test_function_calling(output_version: Literal["v0", "responses/v1", "v1"]) -
assert set(full.tool_calls[0]["args"]) == {"x", "y"}
response = bound_llm.invoke("What was a positive news story from today?")
_check_response(response)
_check_response(response, output_version)
@pytest.mark.default_cassette("test_function_calling.yaml.gz")
@pytest.mark.vcr
def test_function_calling_v1() -> None:
def multiply(x: int, y: int) -> int:
"""return x * y"""
return x * y
llm = ChatOpenAIV1(model=MODEL_NAME)
bound_llm = llm.bind_tools([multiply, {"type": "web_search_preview"}])
ai_msg = bound_llm.invoke("whats 5 * 4")
assert len(ai_msg.tool_calls) == 1
assert ai_msg.tool_calls[0]["name"] == "multiply"
assert set(ai_msg.tool_calls[0]["args"]) == {"x", "y"}
full: Any = None
for chunk in bound_llm.stream("whats 5 * 4"):
assert isinstance(chunk, AIMessageChunkV1)
full = chunk if full is None else full + chunk
assert len(full.tool_calls) == 1
assert full.tool_calls[0]["name"] == "multiply"
assert set(full.tool_calls[0]["args"]) == {"x", "y"}
response = bound_llm.invoke("What was a positive news story from today?")
_check_response(response, "v1")
class Foo(BaseModel):
@ -183,10 +220,8 @@ class FooDict(TypedDict):
@pytest.mark.default_cassette("test_parsed_pydantic_schema.yaml.gz")
@pytest.mark.vcr
@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
def test_parsed_pydantic_schema(
output_version: Literal["v0", "responses/v1", "v1"],
) -> None:
@pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
def test_parsed_pydantic_schema(output_version: Literal["v0", "responses/v1"]) -> None:
llm = ChatOpenAI(
model=MODEL_NAME, use_responses_api=True, output_version=output_version
)
@ -206,6 +241,28 @@ def test_parsed_pydantic_schema(
assert parsed.response
@pytest.mark.default_cassette("test_parsed_pydantic_schema.yaml.gz")
@pytest.mark.vcr
def test_parsed_pydantic_schema_v1() -> None:
llm = ChatOpenAIV1(model=MODEL_NAME, use_responses_api=True)
response = llm.invoke("how are ya", response_format=Foo)
parsed = Foo(**json.loads(response.text))
assert parsed == response.parsed
assert parsed.response
# Test stream
full: Optional[AIMessageChunkV1] = None
chunks = []
for chunk in llm.stream("how are ya", response_format=Foo):
assert isinstance(chunk, AIMessageChunkV1)
full = chunk if full is None else full + chunk
chunks.append(chunk)
assert isinstance(full, AIMessageChunkV1)
parsed = Foo(**json.loads(full.text))
assert parsed == full.parsed
assert parsed.response
async def test_parsed_pydantic_schema_async() -> None:
llm = ChatOpenAI(model=MODEL_NAME, use_responses_api=True)
response = await llm.ainvoke("how are ya", response_format=Foo)
@ -311,8 +368,8 @@ def test_function_calling_and_structured_output() -> None:
@pytest.mark.default_cassette("test_reasoning.yaml.gz")
@pytest.mark.vcr
@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
def test_reasoning(output_version: Literal["v0", "responses/v1", "v1"]) -> None:
@pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
def test_reasoning(output_version: Literal["v0", "responses/v1"]) -> None:
llm = ChatOpenAI(
model="o4-mini", use_responses_api=True, output_version=output_version
)
@ -337,6 +394,26 @@ def test_reasoning(output_version: Literal["v0", "responses/v1", "v1"]) -> None:
assert block_types == ["reasoning", "text"]
@pytest.mark.default_cassette("test_reasoning.yaml.gz")
@pytest.mark.vcr
def test_reasoning_v1() -> None:
llm = ChatOpenAIV1(model="o4-mini", use_responses_api=True)
response = llm.invoke("Hello", reasoning={"effort": "low"})
assert isinstance(response, AIMessageV1)
# Test init params + streaming
llm = ChatOpenAIV1(model="o4-mini", reasoning={"effort": "low"})
full: Optional[AIMessageChunkV1] = None
for chunk in llm.stream("Hello"):
assert isinstance(chunk, AIMessageChunkV1)
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunkV1)
for msg in [response, full]:
block_types = [block["type"] for block in msg.content]
assert block_types == ["reasoning", "text"]
def test_stateful_api() -> None:
llm = ChatOpenAI(model=MODEL_NAME, use_responses_api=True)
response = llm.invoke("how are you, my name is Bobo")
@ -380,14 +457,14 @@ def test_file_search() -> None:
input_message = {"role": "user", "content": "What is deep research by OpenAI?"}
response = llm.invoke([input_message], tools=[tool])
_check_response(response)
_check_response(response, "v0")
full: Optional[BaseMessageChunk] = None
for chunk in llm.stream([input_message], tools=[tool]):
assert isinstance(chunk, AIMessageChunk)
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk)
_check_response(full)
_check_response(full, "v0")
next_message = {"role": "user", "content": "Thank you."}
_ = llm.invoke([input_message, full, next_message])
@ -395,9 +472,9 @@ def test_file_search() -> None:
@pytest.mark.default_cassette("test_stream_reasoning_summary.yaml.gz")
@pytest.mark.vcr
@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
@pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
def test_stream_reasoning_summary(
output_version: Literal["v0", "responses/v1", "v1"],
output_version: Literal["v0", "responses/v1"],
) -> None:
llm = ChatOpenAI(
model="o4-mini",
@ -424,7 +501,8 @@ def test_stream_reasoning_summary(
assert isinstance(block["type"], str)
assert isinstance(block["text"], str)
assert block["text"]
elif output_version == "responses/v1":
else:
# output_version == "responses/v1"
reasoning = next(
block
for block in response_1.content
@ -438,18 +516,6 @@ def test_stream_reasoning_summary(
assert isinstance(block["type"], str)
assert isinstance(block["text"], str)
assert block["text"]
else:
# v1
total_reasoning_blocks = 0
for block in response_1.content:
if block["type"] == "reasoning":
total_reasoning_blocks += 1
assert isinstance(block["id"], str) and block["id"].startswith("rs_")
assert isinstance(block["reasoning"], str)
assert isinstance(block["index"], int)
assert (
total_reasoning_blocks > 1
) # This query typically generates multiple reasoning blocks
# Check we can pass back summaries
message_2 = {"role": "user", "content": "Thank you."}
@ -457,10 +523,45 @@ def test_stream_reasoning_summary(
assert isinstance(response_2, AIMessage)
@pytest.mark.default_cassette("test_stream_reasoning_summary.yaml.gz")
@pytest.mark.vcr
def test_stream_reasoning_summary_v1() -> None:
llm = ChatOpenAIV1(
model="o4-mini",
# Routes to Responses API if `reasoning` is set.
reasoning={"effort": "medium", "summary": "auto"},
)
message_1 = {
"role": "user",
"content": "What was the third tallest buliding in the year 2000?",
}
response_1: Optional[AIMessageChunkV1] = None
for chunk in llm.stream([message_1]):
assert isinstance(chunk, AIMessageChunkV1)
response_1 = chunk if response_1 is None else response_1 + chunk
assert isinstance(response_1, AIMessageChunkV1)
total_reasoning_blocks = 0
for block in response_1.content:
if block["type"] == "reasoning":
total_reasoning_blocks += 1
assert isinstance(block["id"], str) and block["id"].startswith("rs_")
assert isinstance(block["reasoning"], str)
assert isinstance(block["index"], int)
assert (
total_reasoning_blocks > 1
) # This query typically generates multiple reasoning blocks
# Check we can pass back summaries
message_2 = {"role": "user", "content": "Thank you."}
response_2 = llm.invoke([message_1, response_1, message_2])
assert isinstance(response_2, AIMessageV1)
@pytest.mark.default_cassette("test_code_interpreter.yaml.gz")
@pytest.mark.vcr
@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
def test_code_interpreter(output_version: Literal["v0", "responses/v1", "v1"]) -> None:
@pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
def test_code_interpreter(output_version: Literal["v0", "responses/v1"]) -> None:
llm = ChatOpenAI(
model="o4-mini", use_responses_api=True, output_version=output_version
)
@ -473,33 +574,20 @@ def test_code_interpreter(output_version: Literal["v0", "responses/v1", "v1"]) -
}
response = llm_with_tools.invoke([input_message])
assert isinstance(response, AIMessage)
_check_response(response)
_check_response(response, output_version)
if output_version == "v0":
tool_outputs = [
item
for item in response.additional_kwargs["tool_outputs"]
if item["type"] == "code_interpreter_call"
]
elif output_version == "responses/v1":
tool_outputs = [
item
for item in response.content
if isinstance(item, dict) and item["type"] == "code_interpreter_call"
]
else:
# v1
# responses/v1
tool_outputs = [
item
for item in response.content
if isinstance(item, dict) and item["type"] == "code_interpreter_call"
]
code_interpreter_result = next(
item
for item in response.content
if isinstance(item, dict) and item["type"] == "code_interpreter_result"
)
assert tool_outputs
assert code_interpreter_result
assert len(tool_outputs) == 1
# Test streaming
@ -520,25 +608,65 @@ def test_code_interpreter(output_version: Literal["v0", "responses/v1", "v1"]) -
for item in response.additional_kwargs["tool_outputs"]
if item["type"] == "code_interpreter_call"
]
elif output_version == "responses/v1":
else:
# responses/v1
tool_outputs = [
item
for item in response.content
if isinstance(item, dict) and item["type"] == "code_interpreter_call"
]
else:
code_interpreter_call = next(
item
for item in response.content
if isinstance(item, dict) and item["type"] == "code_interpreter_call"
)
code_interpreter_result = next(
item
for item in response.content
if isinstance(item, dict) and item["type"] == "code_interpreter_result"
)
assert code_interpreter_call
assert code_interpreter_result
assert tool_outputs
# Test we can pass back in
next_message = {"role": "user", "content": "Please add more comments to the code."}
_ = llm_with_tools.invoke([input_message, full, next_message])
@pytest.mark.default_cassette("test_code_interpreter.yaml.gz")
@pytest.mark.vcr
def test_code_interpreter_v1() -> None:
llm = ChatOpenAIV1(model="o4-mini", use_responses_api=True)
llm_with_tools = llm.bind_tools(
[{"type": "code_interpreter", "container": {"type": "auto"}}]
)
input_message = {
"role": "user",
"content": "Write and run code to answer the question: what is 3^3?",
}
response = llm_with_tools.invoke([input_message])
assert isinstance(response, AIMessageV1)
_check_response(response, "v1")
tool_outputs = [
item for item in response.content if item["type"] == "code_interpreter_call"
]
code_interpreter_result = next(
item for item in response.content if item["type"] == "code_interpreter_result"
)
assert tool_outputs
assert code_interpreter_result
assert len(tool_outputs) == 1
# Test streaming
# Use same container
container_id = tool_outputs[0]["container_id"]
llm_with_tools = llm.bind_tools(
[{"type": "code_interpreter", "container": container_id}]
)
full: Optional[AIMessageChunkV1] = None
for chunk in llm_with_tools.stream([input_message]):
assert isinstance(chunk, AIMessageChunkV1)
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunkV1)
code_interpreter_call = next(
item for item in full.content if item["type"] == "code_interpreter_call"
)
code_interpreter_result = next(
item for item in full.content if item["type"] == "code_interpreter_result"
)
assert code_interpreter_call
assert code_interpreter_result
assert tool_outputs
# Test we can pass back in
@ -634,9 +762,59 @@ def test_mcp_builtin_zdr() -> None:
_ = llm_with_tools.invoke([input_message, full, approval_message])
@pytest.mark.default_cassette("test_mcp_builtin_zdr.yaml.gz")
@pytest.mark.vcr
def test_mcp_builtin_zdr_v1() -> None:
llm = ChatOpenAIV1(
model="o4-mini", store=False, include=["reasoning.encrypted_content"]
)
llm_with_tools = llm.bind_tools(
[
{
"type": "mcp",
"server_label": "deepwiki",
"server_url": "https://mcp.deepwiki.com/mcp",
"require_approval": {"always": {"tool_names": ["read_wiki_structure"]}},
}
]
)
input_message = {
"role": "user",
"content": (
"What transport protocols does the 2025-03-26 version of the MCP spec "
"support?"
),
}
full: Optional[AIMessageChunkV1] = None
for chunk in llm_with_tools.stream([input_message]):
assert isinstance(chunk, AIMessageChunkV1)
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunkV1)
assert all(isinstance(block, dict) for block in full.content)
approval_message = HumanMessageV1(
[
{
"type": "non_standard",
"value": {
"type": "mcp_approval_response",
"approve": True,
"approval_request_id": block["value"]["id"], # type: ignore[index]
},
}
for block in full.content
if block["type"] == "non_standard"
and block["value"]["type"] == "mcp_approval_request" # type: ignore[index]
]
)
_ = llm_with_tools.invoke([input_message, full, approval_message])
@pytest.mark.default_cassette("test_image_generation_streaming.yaml.gz")
@pytest.mark.vcr
@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
@pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
def test_image_generation_streaming(output_version: str) -> None:
"""Test image generation streaming."""
llm = ChatOpenAI(
@ -710,9 +888,52 @@ def test_image_generation_streaming(output_version: str) -> None:
assert set(standard_keys).issubset(tool_output.keys())
@pytest.mark.default_cassette("test_image_generation_streaming.yaml.gz")
@pytest.mark.vcr
def test_image_generation_streaming_v1() -> None:
"""Test image generation streaming."""
llm = ChatOpenAIV1(model="gpt-4.1", use_responses_api=True)
tool = {
"type": "image_generation",
"quality": "low",
"output_format": "jpeg",
"output_compression": 100,
"size": "1024x1024",
}
expected_keys = {
# Standard
"type",
"base64",
"mime_type",
"id",
"index",
# OpenAI-specific
"background",
"output_format",
"quality",
"revised_prompt",
"size",
"status",
}
full: Optional[AIMessageChunkV1] = None
for chunk in llm.stream("Draw a random short word in green font.", tools=[tool]):
assert isinstance(chunk, AIMessageChunkV1)
full = chunk if full is None else full + chunk
complete_ai_message = cast(AIMessageChunkV1, full)
tool_output = next(
block
for block in complete_ai_message.content
if isinstance(block, dict) and block["type"] == "image"
)
assert set(expected_keys).issubset(tool_output.keys())
@pytest.mark.default_cassette("test_image_generation_multi_turn.yaml.gz")
@pytest.mark.vcr
@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
@pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
def test_image_generation_multi_turn(output_version: str) -> None:
"""Test multi-turn editing of image generation by passing in history."""
# Test multi-turn
@ -735,7 +956,7 @@ def test_image_generation_multi_turn(output_version: str) -> None:
]
ai_message = llm_with_tools.invoke(chat_history)
assert isinstance(ai_message, AIMessage)
_check_response(ai_message)
_check_response(ai_message, output_version)
expected_keys = {
"id",
@ -801,7 +1022,7 @@ def test_image_generation_multi_turn(output_version: str) -> None:
ai_message2 = llm_with_tools.invoke(chat_history)
assert isinstance(ai_message2, AIMessage)
_check_response(ai_message2)
_check_response(ai_message2, output_version)
if output_version == "v0":
tool_output = ai_message2.additional_kwargs["tool_outputs"][0]
@ -821,3 +1042,76 @@ def test_image_generation_multi_turn(output_version: str) -> None:
if isinstance(block, dict) and block["type"] == "image"
)
assert set(standard_keys).issubset(tool_output.keys())
@pytest.mark.default_cassette("test_image_generation_multi_turn.yaml.gz")
@pytest.mark.vcr
def test_image_generation_multi_turn_v1() -> None:
"""Test multi-turn editing of image generation by passing in history."""
# Test multi-turn
llm = ChatOpenAIV1(model="gpt-4.1", use_responses_api=True)
# Test invocation
tool = {
"type": "image_generation",
"quality": "low",
"output_format": "jpeg",
"output_compression": 100,
"size": "1024x1024",
}
llm_with_tools = llm.bind_tools([tool])
chat_history: list[MessageLikeRepresentation] = [
{"role": "user", "content": "Draw a random short word in green font."}
]
ai_message = llm_with_tools.invoke(chat_history)
assert isinstance(ai_message, AIMessageV1)
_check_response(ai_message, "v1")
expected_keys = {
# Standard
"type",
"base64",
"mime_type",
"id",
# OpenAI-specific
"background",
"output_format",
"quality",
"revised_prompt",
"size",
"status",
}
standard_keys = {"type", "base64", "id", "status"}
tool_output = next(
block
for block in ai_message.content
if isinstance(block, dict) and block["type"] == "image"
)
assert set(standard_keys).issubset(tool_output.keys())
chat_history.extend(
[
# AI message with tool output
ai_message,
# New request
{
"role": "user",
"content": (
"Now, change the font to blue. Keep the word and everything else "
"the same."
),
},
]
)
ai_message2 = llm_with_tools.invoke(chat_history)
assert isinstance(ai_message2, AIMessageV1)
_check_response(ai_message2, "v1")
tool_output = next(
block
for block in ai_message2.content
if isinstance(block, dict) and block["type"] == "image"
)
assert set(expected_keys).issubset(tool_output.keys())