feat(openai): v1 message format support (#32296)

ccurme 2025-07-28 19:42:26 -03:00 committed by GitHub
parent 7166adce1f
commit c15e55b33c
17 changed files with 4682 additions and 299 deletions

View File

@ -2,8 +2,7 @@
from __future__ import annotations from __future__ import annotations
import warnings from abc import ABC
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from functools import cache from functools import cache
from typing import ( from typing import (
@ -26,7 +25,6 @@ from langchain_core.messages import (
AnyMessage, AnyMessage,
BaseMessage, BaseMessage,
MessageLikeRepresentation, MessageLikeRepresentation,
get_buffer_string,
) )
from langchain_core.messages.v1 import AIMessage as AIMessageV1 from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.prompt_values import PromptValue from langchain_core.prompt_values import PromptValue
@ -166,7 +164,6 @@ class BaseLanguageModel(
list[AnyMessage], list[AnyMessage],
] ]
@abstractmethod
def generate_prompt( def generate_prompt(
self, self,
prompts: list[PromptValue], prompts: list[PromptValue],
@ -201,7 +198,6 @@ class BaseLanguageModel(
prompt and additional model provider-specific output. prompt and additional model provider-specific output.
""" """
@abstractmethod
async def agenerate_prompt( async def agenerate_prompt(
self, self,
prompts: list[PromptValue], prompts: list[PromptValue],
@ -245,7 +241,6 @@ class BaseLanguageModel(
raise NotImplementedError raise NotImplementedError
@deprecated("0.1.7", alternative="invoke", removal="1.0") @deprecated("0.1.7", alternative="invoke", removal="1.0")
@abstractmethod
def predict( def predict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str: ) -> str:
@ -266,7 +261,6 @@ class BaseLanguageModel(
""" """
@deprecated("0.1.7", alternative="invoke", removal="1.0") @deprecated("0.1.7", alternative="invoke", removal="1.0")
@abstractmethod
def predict_messages( def predict_messages(
self, self,
messages: list[BaseMessage], messages: list[BaseMessage],
@ -291,7 +285,6 @@ class BaseLanguageModel(
""" """
@deprecated("0.1.7", alternative="ainvoke", removal="1.0") @deprecated("0.1.7", alternative="ainvoke", removal="1.0")
@abstractmethod
async def apredict( async def apredict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str: ) -> str:
@ -312,7 +305,6 @@ class BaseLanguageModel(
""" """
@deprecated("0.1.7", alternative="ainvoke", removal="1.0") @deprecated("0.1.7", alternative="ainvoke", removal="1.0")
@abstractmethod
async def apredict_messages( async def apredict_messages(
self, self,
messages: list[BaseMessage], messages: list[BaseMessage],
@ -368,33 +360,6 @@ class BaseLanguageModel(
""" """
return len(self.get_token_ids(text)) return len(self.get_token_ids(text))
def get_num_tokens_from_messages(
self,
messages: list[BaseMessage],
tools: Optional[Sequence] = None,
) -> int:
"""Get the number of tokens in the messages.
Useful for checking if an input fits in a model's context window.
**Note**: the base implementation of get_num_tokens_from_messages ignores
tool schemas.
Args:
messages: The message inputs to tokenize.
tools: If provided, sequence of dict, BaseModel, function, or BaseTools
to be converted to tool schemas.
Returns:
The sum of the number of tokens across the messages.
"""
if tools is not None:
warnings.warn(
"Counting tokens in tool schemas is not yet supported. Ignoring tools.",
stacklevel=2,
)
return sum(self.get_num_tokens(get_buffer_string([m])) for m in messages)
@classmethod @classmethod
def _all_required_field_names(cls) -> set: def _all_required_field_names(cls) -> set:
"""DEPRECATED: Kept for backwards compatibility. """DEPRECATED: Kept for backwards compatibility.

View File

@ -55,12 +55,11 @@ from langchain_core.messages import (
HumanMessage, HumanMessage,
convert_to_messages, convert_to_messages,
convert_to_openai_image_block, convert_to_openai_image_block,
get_buffer_string,
is_data_content_block, is_data_content_block,
message_chunk_to_message, message_chunk_to_message,
) )
from langchain_core.messages import content_blocks as types
from langchain_core.messages.ai import _LC_ID_PREFIX from langchain_core.messages.ai import _LC_ID_PREFIX
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.outputs import ( from langchain_core.outputs import (
ChatGeneration, ChatGeneration,
ChatGenerationChunk, ChatGenerationChunk,
@ -222,23 +221,6 @@ def _format_ls_structured_output(ls_structured_output_format: Optional[dict]) ->
return ls_structured_output_format_dict return ls_structured_output_format_dict
def _convert_to_v1(message: AIMessage) -> AIMessageV1:
"""Best-effort conversion of a V0 AIMessage to V1."""
if isinstance(message.content, str):
content: list[types.ContentBlock] = []
if message.content:
content = [{"type": "text", "text": message.content}]
for tool_call in message.tool_calls:
content.append(tool_call)
return AIMessageV1(
content=content,
usage_metadata=message.usage_metadata,
response_metadata=message.response_metadata,
)
class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
"""Base class for chat models. """Base class for chat models.
@ -1370,6 +1352,33 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
starter_dict["_type"] = self._llm_type starter_dict["_type"] = self._llm_type
return starter_dict return starter_dict
def get_num_tokens_from_messages(
self,
messages: list[BaseMessage],
tools: Optional[Sequence] = None,
) -> int:
"""Get the number of tokens in the messages.
Useful for checking if an input fits in a model's context window.
**Note**: the base implementation of get_num_tokens_from_messages ignores
tool schemas.
Args:
messages: The message inputs to tokenize.
tools: If provided, sequence of dict, BaseModel, function, or BaseTools
to be converted to tool schemas.
Returns:
The sum of the number of tokens across the messages.
"""
if tools is not None:
warnings.warn(
"Counting tokens in tool schemas is not yet supported. Ignoring tools.",
stacklevel=2,
)
return sum(self.get_num_tokens(get_buffer_string([m])) for m in messages)
def bind_tools( def bind_tools(
self, self,
tools: Sequence[ tools: Sequence[
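
The token-counting helper now lives on the chat model base class rather than BaseLanguageModel. A minimal sketch of the call pattern, assuming langchain-openai is installed and OPENAI_API_KEY is set; ChatOpenAI overrides this with tiktoken-based counting, while the base implementation shown above falls back to get_buffer_string and ignores `tools` with a warning:

    from langchain_core.messages import HumanMessage, SystemMessage
    from langchain_openai import ChatOpenAI

    model = ChatOpenAI(model="gpt-4o-mini")
    messages = [
        SystemMessage("You are a terse assistant."),
        HumanMessage("How many tokens is this conversation?"),
    ]
    # Counts tokens locally; this helper does not call the API.
    num_tokens = model.get_num_tokens_from_messages(messages)
    print(num_tokens)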

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import copy import copy
import typing import typing
import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Iterator, Sequence from collections.abc import AsyncIterator, Iterator, Sequence
from operator import itemgetter from operator import itemgetter
@ -38,11 +39,14 @@ from langchain_core.language_models.base import (
) )
from langchain_core.messages import ( from langchain_core.messages import (
AIMessage, AIMessage,
BaseMessage,
convert_to_openai_image_block, convert_to_openai_image_block,
get_buffer_string,
is_data_content_block, is_data_content_block,
) )
from langchain_core.messages.utils import convert_to_messages_v1 from langchain_core.messages.utils import (
_convert_from_v1_message,
convert_to_messages_v1,
)
from langchain_core.messages.v1 import AIMessage as AIMessageV1 from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1 from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
from langchain_core.messages.v1 import HumanMessage as HumanMessageV1 from langchain_core.messages.v1 import HumanMessage as HumanMessageV1
@ -735,7 +739,7 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
*, *,
tool_choice: Optional[Union[str]] = None, tool_choice: Optional[Union[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]: ) -> Runnable[LanguageModelInput, AIMessageV1]:
"""Bind tools to the model. """Bind tools to the model.
Args: Args:
@ -899,6 +903,34 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
return RunnableMap(raw=llm) | parser_with_fallback return RunnableMap(raw=llm) | parser_with_fallback
return llm | output_parser return llm | output_parser
def get_num_tokens_from_messages(
self,
messages: list[MessageV1],
tools: Optional[Sequence] = None,
) -> int:
"""Get the number of tokens in the messages.
Useful for checking if an input fits in a model's context window.
**Note**: the base implementation of get_num_tokens_from_messages ignores
tool schemas.
Args:
messages: The message inputs to tokenize.
tools: If provided, sequence of dict, BaseModel, function, or BaseTools
to be converted to tool schemas.
Returns:
The sum of the number of tokens across the messages.
"""
messages_v0 = [_convert_from_v1_message(message) for message in messages]
if tools is not None:
warnings.warn(
"Counting tokens in tool schemas is not yet supported. Ignoring tools.",
stacklevel=2,
)
return sum(self.get_num_tokens(get_buffer_string([m])) for m in messages_v0)
def _gen_info_and_msg_metadata( def _gen_info_and_msg_metadata(
generation: Union[ChatGeneration, ChatGenerationChunk], generation: Union[ChatGeneration, ChatGenerationChunk],

View File

@ -706,6 +706,7 @@ ToolContentBlock = Union[
ContentBlock = Union[ ContentBlock = Union[
TextContentBlock, TextContentBlock,
ToolCall, ToolCall,
InvalidToolCall,
ReasoningContentBlock, ReasoningContentBlock,
NonStandardContentBlock, NonStandardContentBlock,
DataContentBlock, DataContentBlock,
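
With InvalidToolCall added to the ContentBlock union, a v1 content list can carry a failed tool-call parse alongside ordinary blocks. A hedged illustration using the existing factory from langchain_core.messages.tool (values are made up):

    from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call

    content = [
        {"type": "text", "text": "The tool arguments did not parse."},
        create_invalid_tool_call(name="multiply", args='{"x": 5,', id="call_1", error=None),
    ]
    assert content[1]["type"] == "invalid_tool_call"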

View File

@ -384,38 +384,37 @@ def _convert_from_v1_message(message: MessageV1) -> BaseMessage:
Returns: Returns:
BaseMessage: Converted message instance. BaseMessage: Converted message instance.
""" """
# type ignores here are because AIMessageV1.content is a list of dicts. content = cast("Union[str, list[str | dict]]", message.content)
# AIMessageV0.content expects str or list[str | dict].
if isinstance(message, AIMessageV1): if isinstance(message, AIMessageV1):
return AIMessage( return AIMessage(
content=message.content, # type: ignore[arg-type] content=content,
id=message.id, id=message.id,
name=message.name, name=message.name,
tool_calls=message.tool_calls, tool_calls=message.tool_calls,
response_metadata=message.response_metadata, response_metadata=cast("dict", message.response_metadata),
) )
if isinstance(message, AIMessageChunkV1): if isinstance(message, AIMessageChunkV1):
return AIMessageChunk( return AIMessageChunk(
content=message.content, # type: ignore[arg-type] content=content,
id=message.id, id=message.id,
name=message.name, name=message.name,
tool_call_chunks=message.tool_call_chunks, tool_call_chunks=message.tool_call_chunks,
response_metadata=message.response_metadata, response_metadata=cast("dict", message.response_metadata),
) )
if isinstance(message, HumanMessageV1): if isinstance(message, HumanMessageV1):
return HumanMessage( return HumanMessage(
content=message.content, # type: ignore[arg-type] content=content,
id=message.id, id=message.id,
name=message.name, name=message.name,
) )
if isinstance(message, SystemMessageV1): if isinstance(message, SystemMessageV1):
return SystemMessage( return SystemMessage(
content=message.content, # type: ignore[arg-type] content=content,
id=message.id, id=message.id,
) )
if isinstance(message, ToolMessageV1): if isinstance(message, ToolMessageV1):
return ToolMessage( return ToolMessage(
content=message.content, # type: ignore[arg-type] content=content,
id=message.id, id=message.id,
) )
message = f"Unsupported message type: {type(message)}" message = f"Unsupported message type: {type(message)}"
@ -501,7 +500,10 @@ def _convert_to_message_v1(message: MessageLikeRepresentation) -> MessageV1:
ValueError: if the message dict does not contain the required keys. ValueError: if the message dict does not contain the required keys.
""" """
if isinstance(message, MessageV1Types): if isinstance(message, MessageV1Types):
message_ = message if isinstance(message, AIMessageChunkV1):
message_ = message.to_message()
else:
message_ = message
elif isinstance(message, str): elif isinstance(message, str):
message_ = _create_message_from_message_type_v1("human", message) message_ = _create_message_from_message_type_v1("human", message)
elif isinstance(message, Sequence) and len(message) == 2: elif isinstance(message, Sequence) and len(message) == 2:
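
A small sketch of the new chunk handling in _convert_to_message_v1: v1 chunks are now materialized into full messages via to_message() (assuming langchain-core with this change):

    from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1

    chunk = AIMessageChunkV1("Hello ", id="msg_1") + AIMessageChunkV1("world")
    message = chunk.to_message()  # AIMessageV1 carrying the merged text content
    print(message.text)           # -> "Hello world"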

View File

@ -5,6 +5,8 @@ import uuid
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Literal, Optional, TypedDict, Union, cast, get_args from typing import Any, Literal, Optional, TypedDict, Union, cast, get_args
from pydantic import BaseModel
import langchain_core.messages.content_blocks as types import langchain_core.messages.content_blocks as types
from langchain_core.messages.ai import _LC_ID_PREFIX, UsageMetadata, add_usage from langchain_core.messages.ai import _LC_ID_PREFIX, UsageMetadata, add_usage
from langchain_core.messages.base import merge_content from langchain_core.messages.base import merge_content
@ -32,20 +34,20 @@ def _ensure_id(id_val: Optional[str]) -> str:
return id_val or str(uuid.uuid4()) return id_val or str(uuid.uuid4())
class Provider(TypedDict): class ResponseMetadata(TypedDict, total=False):
"""Information about the provider that generated the message. """Metadata about the response from the AI provider.
Contains metadata about the AI provider and model used to generate content. Contains additional information returned by the provider, such as
response headers, service tiers, log probabilities, system fingerprints, etc.
Attributes: Extra keys are permitted from what is typed here.
name: Name and version of the provider that created the content block.
model_name: Name of the model that generated the content block.
""" """
name: str model_provider: str
"""Name and version of the provider that created the content block.""" """Name and version of the provider that created the message (e.g., openai)."""
model_name: str model_name: str
"""Name of the model that generated the content block.""" """Name of the model that generated the message."""
@dataclass @dataclass
@ -91,21 +93,29 @@ class AIMessage:
usage_metadata: Optional[UsageMetadata] = None usage_metadata: Optional[UsageMetadata] = None
"""If provided, usage metadata for a message, such as token counts.""" """If provided, usage metadata for a message, such as token counts."""
response_metadata: dict = field(default_factory=dict) response_metadata: ResponseMetadata = field(
default_factory=lambda: ResponseMetadata()
)
"""Metadata about the response. """Metadata about the response.
This field should include non-standard data returned by the provider, such as This field should include non-standard data returned by the provider, such as
response headers, service tiers, or log probabilities. response headers, service tiers, or log probabilities.
""" """
parsed: Optional[Union[dict[str, Any], BaseModel]] = None
"""Auto-parsed message contents, if applicable."""
def __init__( def __init__(
self, self,
content: Union[str, list[types.ContentBlock]], content: Union[str, list[types.ContentBlock]],
id: Optional[str] = None, id: Optional[str] = None,
name: Optional[str] = None, name: Optional[str] = None,
lc_version: str = "v1", lc_version: str = "v1",
response_metadata: Optional[dict] = None, response_metadata: Optional[ResponseMetadata] = None,
usage_metadata: Optional[UsageMetadata] = None, usage_metadata: Optional[UsageMetadata] = None,
tool_calls: Optional[list[types.ToolCall]] = None,
invalid_tool_calls: Optional[list[types.InvalidToolCall]] = None,
parsed: Optional[Union[dict[str, Any], BaseModel]] = None,
): ):
"""Initialize an AI message. """Initialize an AI message.
@ -116,6 +126,11 @@ class AIMessage:
lc_version: Encoding version for the message. lc_version: Encoding version for the message.
response_metadata: Optional metadata about the response. response_metadata: Optional metadata about the response.
usage_metadata: Optional metadata about token usage. usage_metadata: Optional metadata about token usage.
tool_calls: Optional list of tool calls made by the AI. Tool calls should
generally be included in message content. If passed on init, they will
be added to the content list.
invalid_tool_calls: Optional list of tool calls that failed validation.
parsed: Optional auto-parsed message contents, if applicable.
""" """
if isinstance(content, str): if isinstance(content, str):
self.content = [{"type": "text", "text": content}] self.content = [{"type": "text", "text": content}]
@ -126,13 +141,27 @@ class AIMessage:
self.name = name self.name = name
self.lc_version = lc_version self.lc_version = lc_version
self.usage_metadata = usage_metadata self.usage_metadata = usage_metadata
self.parsed = parsed
if response_metadata is None: if response_metadata is None:
self.response_metadata = {} self.response_metadata = {}
else: else:
self.response_metadata = response_metadata self.response_metadata = response_metadata
self._tool_calls: list[types.ToolCall] = [] # Add tool calls to content if provided on init
self._invalid_tool_calls: list[types.InvalidToolCall] = [] if tool_calls:
content_tool_calls = {
block["id"]
for block in self.content
if block["type"] == "tool_call" and "id" in block
}
for tool_call in tool_calls:
if "id" in tool_call and tool_call["id"] in content_tool_calls:
continue
self.content.append(tool_call)
self._tool_calls = [
block for block in self.content if block["type"] == "tool_call"
]
self.invalid_tool_calls = invalid_tool_calls or []
@property @property
def text(self) -> Optional[str]: def text(self) -> Optional[str]:
@ -150,7 +179,7 @@ class AIMessage:
tool_calls = [block for block in self.content if block["type"] == "tool_call"] tool_calls = [block for block in self.content if block["type"] == "tool_call"]
if tool_calls: if tool_calls:
self._tool_calls = tool_calls self._tool_calls = tool_calls
return self._tool_calls return [block for block in self.content if block["type"] == "tool_call"]
@tool_calls.setter @tool_calls.setter
def tool_calls(self, value: list[types.ToolCall]) -> None: def tool_calls(self, value: list[types.ToolCall]) -> None:
@ -202,13 +231,16 @@ class AIMessageChunk:
These data represent incremental usage statistics, as opposed to a running total. These data represent incremental usage statistics, as opposed to a running total.
""" """
response_metadata: dict = field(init=False) response_metadata: ResponseMetadata = field(init=False)
"""Metadata about the response chunk. """Metadata about the response chunk.
This field should include non-standard data returned by the provider, such as This field should include non-standard data returned by the provider, such as
response headers, service tiers, or log probabilities. response headers, service tiers, or log probabilities.
""" """
parsed: Optional[Union[dict[str, Any], BaseModel]] = None
"""Auto-parsed message contents, if applicable."""
tool_call_chunks: list[types.ToolCallChunk] = field(init=False) tool_call_chunks: list[types.ToolCallChunk] = field(init=False)
def __init__( def __init__(
@ -217,9 +249,10 @@ class AIMessageChunk:
id: Optional[str] = None, id: Optional[str] = None,
name: Optional[str] = None, name: Optional[str] = None,
lc_version: str = "v1", lc_version: str = "v1",
response_metadata: Optional[dict] = None, response_metadata: Optional[ResponseMetadata] = None,
usage_metadata: Optional[UsageMetadata] = None, usage_metadata: Optional[UsageMetadata] = None,
tool_call_chunks: Optional[list[types.ToolCallChunk]] = None, tool_call_chunks: Optional[list[types.ToolCallChunk]] = None,
parsed: Optional[Union[dict[str, Any], BaseModel]] = None,
): ):
"""Initialize an AI message. """Initialize an AI message.
@ -231,6 +264,7 @@ class AIMessageChunk:
response_metadata: Optional metadata about the response. response_metadata: Optional metadata about the response.
usage_metadata: Optional metadata about token usage. usage_metadata: Optional metadata about token usage.
tool_call_chunks: Optional list of partial tool call data. tool_call_chunks: Optional list of partial tool call data.
parsed: Optional auto-parsed message contents, if applicable.
""" """
if isinstance(content, str): if isinstance(content, str):
self.content = [{"type": "text", "text": content, "index": 0}] self.content = [{"type": "text", "text": content, "index": 0}]
@ -241,6 +275,7 @@ class AIMessageChunk:
self.name = name self.name = name
self.lc_version = lc_version self.lc_version = lc_version
self.usage_metadata = usage_metadata self.usage_metadata = usage_metadata
self.parsed = parsed
if response_metadata is None: if response_metadata is None:
self.response_metadata = {} self.response_metadata = {}
else: else:
@ -251,7 +286,7 @@ class AIMessageChunk:
self.tool_call_chunks = tool_call_chunks self.tool_call_chunks = tool_call_chunks
self._tool_calls: list[types.ToolCall] = [] self._tool_calls: list[types.ToolCall] = []
self._invalid_tool_calls: list[types.InvalidToolCall] = [] self.invalid_tool_calls: list[types.InvalidToolCall] = []
self._init_tool_calls() self._init_tool_calls()
def _init_tool_calls(self) -> None: def _init_tool_calls(self) -> None:
@ -264,7 +299,7 @@ class AIMessageChunk:
ValueError: If the tool call chunks are malformed. ValueError: If the tool call chunks are malformed.
""" """
self._tool_calls = [] self._tool_calls = []
self._invalid_tool_calls = [] self.invalid_tool_calls = []
if not self.tool_call_chunks: if not self.tool_call_chunks:
if self._tool_calls: if self._tool_calls:
self.tool_call_chunks = [ self.tool_call_chunks = [
@ -276,14 +311,14 @@ class AIMessageChunk:
) )
for tc in self._tool_calls for tc in self._tool_calls
] ]
if self._invalid_tool_calls: if self.invalid_tool_calls:
tool_call_chunks = self.tool_call_chunks tool_call_chunks = self.tool_call_chunks
tool_call_chunks.extend( tool_call_chunks.extend(
[ [
create_tool_call_chunk( create_tool_call_chunk(
name=tc["name"], args=tc["args"], id=tc["id"], index=None name=tc["name"], args=tc["args"], id=tc["id"], index=None
) )
for tc in self._invalid_tool_calls for tc in self.invalid_tool_calls
] ]
) )
self.tool_call_chunks = tool_call_chunks self.tool_call_chunks = tool_call_chunks
@ -294,9 +329,9 @@ class AIMessageChunk:
def add_chunk_to_invalid_tool_calls(chunk: ToolCallChunk) -> None: def add_chunk_to_invalid_tool_calls(chunk: ToolCallChunk) -> None:
invalid_tool_calls.append( invalid_tool_calls.append(
create_invalid_tool_call( create_invalid_tool_call(
name=chunk["name"], name=chunk.get("name", ""),
args=chunk["args"], args=chunk.get("args", ""),
id=chunk["id"], id=chunk.get("id", ""),
error=None, error=None,
) )
) )
@ -307,9 +342,9 @@ class AIMessageChunk:
if isinstance(args_, dict): if isinstance(args_, dict):
tool_calls.append( tool_calls.append(
create_tool_call( create_tool_call(
name=chunk["name"] or "", name=chunk.get("name", ""),
args=args_, args=args_,
id=chunk["id"], id=chunk.get("id", ""),
) )
) )
else: else:
@ -317,7 +352,7 @@ class AIMessageChunk:
except Exception: except Exception:
add_chunk_to_invalid_tool_calls(chunk) add_chunk_to_invalid_tool_calls(chunk)
self._tool_calls = tool_calls self._tool_calls = tool_calls
self._invalid_tool_calls = invalid_tool_calls self.invalid_tool_calls = invalid_tool_calls
@property @property
def text(self) -> Optional[str]: def text(self) -> Optional[str]:
@ -361,6 +396,20 @@ class AIMessageChunk:
error_msg = "Can only add AIMessageChunk or sequence of AIMessageChunk." error_msg = "Can only add AIMessageChunk or sequence of AIMessageChunk."
raise NotImplementedError(error_msg) raise NotImplementedError(error_msg)
def to_message(self) -> "AIMessage":
"""Convert this AIMessageChunk to an AIMessage."""
return AIMessage(
content=self.content,
id=self.id,
name=self.name,
lc_version=self.lc_version,
response_metadata=self.response_metadata,
usage_metadata=self.usage_metadata,
tool_calls=self.tool_calls,
invalid_tool_calls=self.invalid_tool_calls,
parsed=self.parsed,
)
def add_ai_message_chunks( def add_ai_message_chunks(
left: AIMessageChunk, *others: AIMessageChunk left: AIMessageChunk, *others: AIMessageChunk
@ -373,7 +422,8 @@ def add_ai_message_chunks(
*(cast("list[str | dict[Any, Any]]", o.content) for o in others), *(cast("list[str | dict[Any, Any]]", o.content) for o in others),
) )
response_metadata = merge_dicts( response_metadata = merge_dicts(
left.response_metadata, *(o.response_metadata for o in others) cast("dict", left.response_metadata),
*(cast("dict", o.response_metadata) for o in others),
) )
# Merge tool call chunks # Merge tool call chunks
@ -400,6 +450,15 @@ def add_ai_message_chunks(
else: else:
usage_metadata = None usage_metadata = None
# Parsed
# 'parsed' always represents an aggregation not an incremental value, so the last
# non-null value is kept.
parsed = None
for m in reversed([left, *others]):
if m.parsed is not None:
parsed = m.parsed
break
chunk_id = None chunk_id = None
candidates = [left.id] + [o.id for o in others] candidates = [left.id] + [o.id for o in others]
# first pass: pick the first non-run-* id # first pass: pick the first non-run-* id
@ -417,8 +476,9 @@ def add_ai_message_chunks(
return left.__class__( return left.__class__(
content=cast("list[types.ContentBlock]", content), content=cast("list[types.ContentBlock]", content),
tool_call_chunks=tool_call_chunks, tool_call_chunks=tool_call_chunks,
response_metadata=response_metadata, response_metadata=cast("ResponseMetadata", response_metadata),
usage_metadata=usage_metadata, usage_metadata=usage_metadata,
parsed=parsed,
id=chunk_id, id=chunk_id,
) )
@ -455,19 +515,25 @@ class HumanMessage:
""" """
def __init__( def __init__(
self, content: Union[str, list[types.ContentBlock]], id: Optional[str] = None self,
content: Union[str, list[types.ContentBlock]],
*,
id: Optional[str] = None,
name: Optional[str] = None,
): ):
"""Initialize a human message. """Initialize a human message.
Args: Args:
content: Message content as string or list of content blocks. content: Message content as string or list of content blocks.
id: Optional unique identifier for the message. id: Optional unique identifier for the message.
name: Optional human-readable name for the message.
""" """
self.id = _ensure_id(id) self.id = _ensure_id(id)
if isinstance(content, str): if isinstance(content, str):
self.content = [{"type": "text", "text": content}] self.content = [{"type": "text", "text": content}]
else: else:
self.content = content self.content = content
self.name = name
def text(self) -> str: def text(self) -> str:
"""Extract all text content from the message. """Extract all text content from the message.
@ -497,20 +563,47 @@ class SystemMessage:
content: list[types.ContentBlock] content: list[types.ContentBlock]
type: Literal["system"] = "system" type: Literal["system"] = "system"
name: Optional[str] = None
"""An optional name for the message.
This can be used to provide a human-readable name for the message.
Usage of this field is optional, and whether it's used or not is up to the
model implementation.
"""
custom_role: Optional[str] = None
"""If provided, a custom role for the system message.
Example: ``"developer"``.
Integration packages may use this field to assign the system message role if it
contains a recognized value.
"""
def __init__( def __init__(
self, content: Union[str, list[types.ContentBlock]], *, id: Optional[str] = None self,
content: Union[str, list[types.ContentBlock]],
*,
id: Optional[str] = None,
custom_role: Optional[str] = None,
name: Optional[str] = None,
): ):
"""Initialize a system message. """Initialize a human message.
Args: Args:
content: System instructions as string or list of content blocks. content: Message content as string or list of content blocks.
id: Optional unique identifier for the message. id: Optional unique identifier for the message.
custom_role: If provided, a custom role for the system message.
name: Optional human-readable name for the message.
""" """
self.id = _ensure_id(id) self.id = _ensure_id(id)
if isinstance(content, str): if isinstance(content, str):
self.content = [{"type": "text", "text": content}] self.content = [{"type": "text", "text": content}]
else: else:
self.content = content self.content = content
self.custom_role = custom_role
self.name = name
def text(self) -> str: def text(self) -> str:
"""Extract all text content from the system message.""" """Extract all text content from the system message."""
@ -537,11 +630,51 @@ class ToolMessage:
id: str id: str
tool_call_id: str tool_call_id: str
content: list[dict[str, Any]] content: list[types.ContentBlock]
artifact: Optional[Any] = None # App-side payload not for the model artifact: Optional[Any] = None # App-side payload not for the model
name: Optional[str] = None
"""An optional name for the message.
This can be used to provide a human-readable name for the message.
Usage of this field is optional, and whether it's used or not is up to the
model implementation.
"""
status: Literal["success", "error"] = "success" status: Literal["success", "error"] = "success"
type: Literal["tool"] = "tool" type: Literal["tool"] = "tool"
def __init__(
self,
content: Union[str, list[types.ContentBlock]],
tool_call_id: str,
*,
id: Optional[str] = None,
name: Optional[str] = None,
artifact: Optional[Any] = None,
status: Literal["success", "error"] = "success",
):
"""Initialize a human message.
Args:
content: Message content as string or list of content blocks.
tool_call_id: ID of the tool call this message responds to.
id: Optional unique identifier for the message.
name: Optional human-readable name for the message.
artifact: Optional app-side payload not intended for the model.
status: Execution status ("success" or "error").
"""
self.id = _ensure_id(id)
self.tool_call_id = tool_call_id
if isinstance(content, str):
self.content = [{"type": "text", "text": content}]
else:
self.content = content
self.name = name
self.artifact = artifact
self.status = status
@property @property
def text(self) -> str: def text(self) -> str:
"""Extract all text content from the tool message.""" """Extract all text content from the tool message."""

View File

@ -9,7 +9,7 @@ from typing import Annotated, Any, Optional
from pydantic import SkipValidation, ValidationError from pydantic import SkipValidation, ValidationError
from langchain_core.exceptions import OutputParserException from langchain_core.exceptions import OutputParserException
from langchain_core.messages import AIMessage, InvalidToolCall from langchain_core.messages import AIMessage, InvalidToolCall, ToolCall
from langchain_core.messages.tool import invalid_tool_call from langchain_core.messages.tool import invalid_tool_call
from langchain_core.messages.tool import tool_call as create_tool_call from langchain_core.messages.tool import tool_call as create_tool_call
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
@ -26,7 +26,7 @@ def parse_tool_call(
partial: bool = False, partial: bool = False,
strict: bool = False, strict: bool = False,
return_id: bool = True, return_id: bool = True,
) -> Optional[dict[str, Any]]: ) -> Optional[ToolCall]:
"""Parse a single tool call. """Parse a single tool call.
Args: Args:

View File

@ -8,17 +8,65 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Sequence from collections.abc import Sequence
from typing import Literal, cast from typing import Literal, Union, cast
from typing_extensions import TypedDict from typing_extensions import TypedDict, overload
from langchain_core.load.serializable import Serializable from langchain_core.load.serializable import Serializable
from langchain_core.messages import ( from langchain_core.messages import (
AIMessage,
AnyMessage, AnyMessage,
BaseMessage, BaseMessage,
HumanMessage, HumanMessage,
SystemMessage,
ToolMessage,
get_buffer_string, get_buffer_string,
) )
from langchain_core.messages import content_blocks as types
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.messages.v1 import HumanMessage as HumanMessageV1
from langchain_core.messages.v1 import MessageV1, ResponseMetadata
from langchain_core.messages.v1 import SystemMessage as SystemMessageV1
from langchain_core.messages.v1 import ToolMessage as ToolMessageV1
def _convert_to_v1(message: BaseMessage) -> MessageV1:
"""Best-effort conversion of a V0 AIMessage to V1."""
if isinstance(message.content, str):
content: list[types.ContentBlock] = []
if message.content:
content = [{"type": "text", "text": message.content}]
else:
content = []
for block in message.content:
if isinstance(block, str):
content.append({"type": "text", "text": block})
elif isinstance(block, dict):
content.append(cast("types.ContentBlock", block))
else:
pass
if isinstance(message, HumanMessage):
return HumanMessageV1(content=content)
if isinstance(message, AIMessage):
for tool_call in message.tool_calls:
content.append(tool_call)
return AIMessageV1(
content=content,
usage_metadata=message.usage_metadata,
response_metadata=cast("ResponseMetadata", message.response_metadata),
tool_calls=message.tool_calls,
)
if isinstance(message, SystemMessage):
return SystemMessageV1(content=content)
if isinstance(message, ToolMessage):
return ToolMessageV1(
tool_call_id=message.tool_call_id,
content=content,
artifact=message.artifact,
)
error_message = f"Unsupported message type: {type(message)}"
raise TypeError(error_message)
class PromptValue(Serializable, ABC): class PromptValue(Serializable, ABC):
@ -46,8 +94,18 @@ class PromptValue(Serializable, ABC):
def to_string(self) -> str: def to_string(self) -> str:
"""Return prompt value as string.""" """Return prompt value as string."""
@overload
def to_messages(
self, output_version: Literal["v0"] = "v0"
) -> list[BaseMessage]: ...
@overload
def to_messages(self, output_version: Literal["v1"]) -> list[MessageV1]: ...
@abstractmethod @abstractmethod
def to_messages(self) -> list[BaseMessage]: def to_messages(
self, output_version: Literal["v0", "v1"] = "v0"
) -> Union[Sequence[BaseMessage], Sequence[MessageV1]]:
"""Return prompt as a list of Messages.""" """Return prompt as a list of Messages."""
@ -71,8 +129,20 @@ class StringPromptValue(PromptValue):
"""Return prompt as string.""" """Return prompt as string."""
return self.text return self.text
def to_messages(self) -> list[BaseMessage]: @overload
def to_messages(
self, output_version: Literal["v0"] = "v0"
) -> list[BaseMessage]: ...
@overload
def to_messages(self, output_version: Literal["v1"]) -> list[MessageV1]: ...
def to_messages(
self, output_version: Literal["v0", "v1"] = "v0"
) -> Union[Sequence[BaseMessage], Sequence[MessageV1]]:
"""Return prompt as messages.""" """Return prompt as messages."""
if output_version == "v1":
return [HumanMessageV1(content=self.text)]
return [HumanMessage(content=self.text)] return [HumanMessage(content=self.text)]
@ -89,8 +159,24 @@ class ChatPromptValue(PromptValue):
"""Return prompt as string.""" """Return prompt as string."""
return get_buffer_string(self.messages) return get_buffer_string(self.messages)
def to_messages(self) -> list[BaseMessage]: @overload
"""Return prompt as a list of messages.""" def to_messages(
self, output_version: Literal["v0"] = "v0"
) -> list[BaseMessage]: ...
@overload
def to_messages(self, output_version: Literal["v1"]) -> list[MessageV1]: ...
def to_messages(
self, output_version: Literal["v0", "v1"] = "v0"
) -> Union[Sequence[BaseMessage], Sequence[MessageV1]]:
"""Return prompt as a list of messages.
Args:
output_version: The output version, either "v0" (default) or "v1".
"""
if output_version == "v1":
return [_convert_to_v1(m) for m in self.messages]
return list(self.messages) return list(self.messages)
@classmethod @classmethod
@ -125,8 +211,26 @@ class ImagePromptValue(PromptValue):
"""Return prompt (image URL) as string.""" """Return prompt (image URL) as string."""
return self.image_url["url"] return self.image_url["url"]
def to_messages(self) -> list[BaseMessage]: @overload
def to_messages(
self, output_version: Literal["v0"] = "v0"
) -> list[BaseMessage]: ...
@overload
def to_messages(self, output_version: Literal["v1"]) -> list[MessageV1]: ...
def to_messages(
self, output_version: Literal["v0", "v1"] = "v0"
) -> Union[Sequence[BaseMessage], Sequence[MessageV1]]:
"""Return prompt (image URL) as messages.""" """Return prompt (image URL) as messages."""
if output_version == "v1":
block: types.ImageContentBlock = {
"type": "image",
"url": self.image_url["url"],
}
if "detail" in self.image_url:
block["detail"] = self.image_url["detail"]
return [HumanMessageV1(content=[block])]
return [HumanMessage(content=[cast("dict", self.image_url)])] return [HumanMessage(content=[cast("dict", self.image_url)])]
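
A hedged sketch of the new overloads, assuming this change: to_messages() keeps returning v0 BaseMessage objects by default, while output_version="v1" routes through _convert_to_v1 and the v1 dataclasses:

    from langchain_core.messages import SystemMessage
    from langchain_core.messages.v1 import SystemMessage as SystemMessageV1
    from langchain_core.prompts import ChatPromptTemplate

    prompt = ChatPromptTemplate.from_messages(
        [("system", "You are terse."), ("human", "{question}")]
    )
    value = prompt.invoke({"question": "What is 2 + 2?"})
    v0_messages = value.to_messages()       # list[BaseMessage]
    v1_messages = value.to_messages("v1")   # list[MessageV1]
    assert isinstance(v0_messages[0], SystemMessage)
    assert isinstance(v1_messages[0], SystemMessageV1)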

View File

@ -67,6 +67,7 @@ langchain-text-splitters = { path = "../text-splitters" }
strict = "True" strict = "True"
strict_bytes = "True" strict_bytes = "True"
enable_error_code = "deprecated" enable_error_code = "deprecated"
disable_error_code = ["typeddict-unknown-key"]
# TODO: activate for 'strict' checking # TODO: activate for 'strict' checking
disallow_any_generics = "False" disallow_any_generics = "False"

View File

@ -34,6 +34,7 @@ from langchain_core.messages.content_blocks import KNOWN_BLOCK_TYPES
from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
from langchain_core.messages.tool import tool_call as create_tool_call from langchain_core.messages.tool import tool_call as create_tool_call
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1 from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
from langchain_core.utils._merge import merge_lists from langchain_core.utils._merge import merge_lists
@ -197,7 +198,7 @@ def test_message_chunks() -> None:
assert (meaningful_id + default_id).id == "msg_def456" assert (meaningful_id + default_id).id == "msg_def456"
def test_message_chunks_v2() -> None: def test_message_chunks_v1() -> None:
left = AIMessageChunkV1("foo ", id="abc") left = AIMessageChunkV1("foo ", id="abc")
right = AIMessageChunkV1("bar") right = AIMessageChunkV1("bar")
expected = AIMessageChunkV1("foo bar", id="abc") expected = AIMessageChunkV1("foo bar", id="abc")
@ -230,7 +231,19 @@ def test_message_chunks_v2() -> None:
) )
], ],
) )
assert one + two + three == expected result = one + two + three
assert result == expected
assert result.to_message() == AIMessageV1(
content=[
{
"name": "tool1",
"args": {"arg1": "value}"},
"id": "1",
"type": "tool_call",
}
]
)
assert ( assert (
AIMessageChunkV1( AIMessageChunkV1(
@ -1326,6 +1339,7 @@ def test_known_block_types() -> None:
"text", "text",
"text-plain", "text-plain",
"tool_call", "tool_call",
"invalid_tool_call",
"reasoning", "reasoning",
"non_standard", "non_standard",
"image", "image",

View File

@ -1,10 +1,11 @@
from langchain_openai.chat_models import AzureChatOpenAI, ChatOpenAI from langchain_openai.chat_models import AzureChatOpenAI, ChatOpenAI, ChatOpenAIV1
from langchain_openai.embeddings import AzureOpenAIEmbeddings, OpenAIEmbeddings from langchain_openai.embeddings import AzureOpenAIEmbeddings, OpenAIEmbeddings
from langchain_openai.llms import AzureOpenAI, OpenAI from langchain_openai.llms import AzureOpenAI, OpenAI
__all__ = [ __all__ = [
"OpenAI", "OpenAI",
"ChatOpenAI", "ChatOpenAI",
"ChatOpenAIV1",
"OpenAIEmbeddings", "OpenAIEmbeddings",
"AzureOpenAI", "AzureOpenAI",
"AzureChatOpenAI", "AzureChatOpenAI",

View File

@ -1,4 +1,5 @@
from langchain_openai.chat_models.azure import AzureChatOpenAI from langchain_openai.chat_models.azure import AzureChatOpenAI
from langchain_openai.chat_models.base import ChatOpenAI from langchain_openai.chat_models.base import ChatOpenAI
from langchain_openai.chat_models.base_v1 import ChatOpenAI as ChatOpenAIV1
__all__ = ["ChatOpenAI", "AzureChatOpenAI"] __all__ = ["ChatOpenAI", "AzureChatOpenAI", "ChatOpenAIV1"]

View File

@ -66,11 +66,14 @@ For backwards compatibility, this module provides functions to convert between t
formats. The functions are used internally by ChatOpenAI. formats. The functions are used internally by ChatOpenAI.
""" # noqa: E501 """ # noqa: E501
import copy
import json import json
from collections.abc import Iterable, Iterator from collections.abc import Iterable, Iterator
from typing import Any, Literal, Union, cast from typing import Any, Literal, Optional, Union, cast
from langchain_core.messages import AIMessage, AIMessageChunk, is_data_content_block from langchain_core.messages import AIMessage, AIMessageChunk, is_data_content_block
from langchain_core.messages import content_blocks as types
from langchain_core.messages.v1 import AIMessage as AIMessageV1
_FUNCTION_CALL_IDS_MAP_KEY = "__openai_function_call_ids__" _FUNCTION_CALL_IDS_MAP_KEY = "__openai_function_call_ids__"
@ -289,25 +292,21 @@ def _convert_to_v1_from_chat_completions_chunk(chunk: AIMessageChunk) -> AIMessa
return cast(AIMessageChunk, result) return cast(AIMessageChunk, result)
def _convert_from_v1_to_chat_completions(message: AIMessage) -> AIMessage: def _convert_from_v1_to_chat_completions(message: AIMessageV1) -> AIMessageV1:
"""Convert a v1 message to the Chat Completions format.""" """Convert a v1 message to the Chat Completions format."""
if isinstance(message.content, list): new_content: list[types.ContentBlock] = []
new_content: list = [] for block in message.content:
for block in message.content: if block["type"] == "text":
if isinstance(block, dict): # Strip annotations
block_type = block.get("type") new_content.append({"type": "text", "text": block["text"]})
if block_type == "text": elif block["type"] in ("reasoning", "tool_call"):
# Strip annotations pass
new_content.append({"type": "text", "text": block["text"]}) else:
elif block_type in ("reasoning", "tool_call"): new_content.append(block)
pass new_message = copy.copy(message)
else: new_message.content = new_content
new_content.append(block)
else:
new_content.append(block)
return message.model_copy(update={"content": new_content})
return message return new_message
# v1 / Responses # v1 / Responses
@ -319,17 +318,18 @@ def _convert_annotation_to_v1(annotation: dict[str, Any]) -> dict[str, Any]:
for field in ("end_index", "start_index", "title"): for field in ("end_index", "start_index", "title"):
if field in annotation: if field in annotation:
url_citation[field] = annotation[field] url_citation[field] = annotation[field]
url_citation["type"] = "url_citation" url_citation["type"] = "citation"
url_citation["url"] = annotation["url"] url_citation["url"] = annotation["url"]
return url_citation return url_citation
elif annotation_type == "file_citation": elif annotation_type == "file_citation":
document_citation = {"type": "document_citation"} document_citation = {"type": "citation"}
if "filename" in annotation: if "filename" in annotation:
document_citation["title"] = annotation["filename"] document_citation["title"] = annotation["filename"]
for field in ("file_id", "index"): # OpenAI-specific if "file_id" in annotation:
if field in annotation: document_citation["file_id"] = annotation["file_id"]
document_citation[field] = annotation[field] if "index" in annotation:
document_citation["file_index"] = annotation["index"]
return document_citation return document_citation
# TODO: standardise container_file_citation? # TODO: standardise container_file_citation?
@ -367,13 +367,15 @@ def _explode_reasoning(block: dict[str, Any]) -> Iterable[dict[str, Any]]:
yield new_block yield new_block
def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage: def _convert_to_v1_from_responses(
content: list[dict[str, Any]],
tool_calls: Optional[list[types.ToolCall]] = None,
invalid_tool_calls: Optional[list[types.InvalidToolCall]] = None,
) -> list[types.ContentBlock]:
"""Mutate a Responses message to v1 format.""" """Mutate a Responses message to v1 format."""
if not isinstance(message.content, list):
return message
def _iter_blocks() -> Iterable[dict[str, Any]]: def _iter_blocks() -> Iterable[dict[str, Any]]:
for block in message.content: for block in content:
if not isinstance(block, dict): if not isinstance(block, dict):
continue continue
block_type = block.get("type") block_type = block.get("type")
@ -409,13 +411,24 @@ def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage:
yield new_block yield new_block
elif block_type == "function_call": elif block_type == "function_call":
new_block = {"type": "tool_call", "id": block.get("call_id", "")} new_block = None
if "id" in block: call_id = block.get("call_id", "")
new_block["item_id"] = block["id"] if call_id:
for extra_key in ("arguments", "name", "index"): for tool_call in tool_calls or []:
if extra_key in block: if tool_call.get("id") == call_id:
new_block[extra_key] = block[extra_key] new_block = tool_call.copy()
yield new_block break
else:
for invalid_tool_call in invalid_tool_calls or []:
if invalid_tool_call.get("id") == call_id:
new_block = invalid_tool_call.copy()
break
if new_block:
if "id" in block:
new_block["item_id"] = block["id"]
if "index" in block:
new_block["index"] = block["index"]
yield new_block
elif block_type == "web_search_call": elif block_type == "web_search_call":
web_search_call = {"type": "web_search_call", "id": block["id"]} web_search_call = {"type": "web_search_call", "id": block["id"]}
@ -485,28 +498,26 @@ def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage:
new_block["index"] = new_block["value"].pop("index") new_block["index"] = new_block["value"].pop("index")
yield new_block yield new_block
# Replace the list with the fully converted one return list(_iter_blocks())
message.content = list(_iter_blocks())
return message
def _convert_annotation_from_v1(annotation: dict[str, Any]) -> dict[str, Any]: def _convert_annotation_from_v1(annotation: types.Annotation) -> dict[str, Any]:
annotation_type = annotation.get("type") if annotation["type"] == "citation":
if "url" in annotation:
return {**annotation, "type": "url_citation"}
if annotation_type == "document_citation":
new_ann: dict[str, Any] = {"type": "file_citation"} new_ann: dict[str, Any] = {"type": "file_citation"}
if "title" in annotation: if "title" in annotation:
new_ann["filename"] = annotation["title"] new_ann["filename"] = annotation["title"]
if "file_id" in annotation:
for fld in ("file_id", "index"): new_ann["file_id"] = annotation["file_id"]
if fld in annotation: if "file_index" in annotation:
new_ann[fld] = annotation[fld] new_ann["index"] = annotation["file_index"]
return new_ann return new_ann
elif annotation_type == "non_standard_annotation": elif annotation["type"] == "non_standard_annotation":
return annotation["value"] return annotation["value"]
else: else:
@ -528,7 +539,10 @@ def _implode_reasoning_blocks(blocks: list[dict[str, Any]]) -> Iterable[dict[str
elif "reasoning" not in block and "summary" not in block: elif "reasoning" not in block and "summary" not in block:
# {"type": "reasoning", "id": "rs_..."} # {"type": "reasoning", "id": "rs_..."}
oai_format = {**block, "summary": []} oai_format = {**block, "summary": []}
# Update key order
oai_format["type"] = oai_format.pop("type", "reasoning") oai_format["type"] = oai_format.pop("type", "reasoning")
if "encrypted_content" in oai_format:
oai_format["encrypted_content"] = oai_format.pop("encrypted_content")
yield oai_format yield oai_format
i += 1 i += 1
continue continue
@ -594,13 +608,11 @@ def _consolidate_calls(
# If this really is the matching “result” collapse # If this really is the matching “result” collapse
if nxt.get("type") == result_name and nxt.get("id") == current.get("id"): if nxt.get("type") == result_name and nxt.get("id") == current.get("id"):
if call_name == "web_search_call": if call_name == "web_search_call":
collapsed = { collapsed = {"id": current["id"]}
"id": current["id"],
"status": current["status"],
"type": "web_search_call",
}
if "action" in current: if "action" in current:
collapsed["action"] = current["action"] collapsed["action"] = current["action"]
collapsed["status"] = current["status"]
collapsed["type"] = "web_search_call"
if call_name == "code_interpreter_call": if call_name == "code_interpreter_call":
collapsed = {"id": current["id"]} collapsed = {"id": current["id"]}
@ -621,51 +633,50 @@ def _consolidate_calls(
yield nxt yield nxt
def _convert_from_v1_to_responses(message: AIMessage) -> AIMessage: def _convert_from_v1_to_responses(
if not isinstance(message.content, list): content: list[types.ContentBlock], tool_calls: list[types.ToolCall]
return message ) -> list[dict[str, Any]]:
new_content: list = [] new_content: list = []
for block in message.content: for block in content:
if isinstance(block, dict): if block["type"] == "text" and "annotations" in block:
block_type = block.get("type") # Need a copy because we're changing the annotations list
if block_type == "text" and "annotations" in block: new_block = dict(block)
# Need a copy because we're changing the annotations list new_block["annotations"] = [
new_block = dict(block) _convert_annotation_from_v1(a) for a in block["annotations"]
new_block["annotations"] = [ ]
_convert_annotation_from_v1(a) for a in block["annotations"] new_content.append(new_block)
elif block["type"] == "tool_call":
new_block = {"type": "function_call", "call_id": block["id"]}
if "item_id" in block:
new_block["id"] = block["item_id"] # type: ignore[typeddict-item]
if "name" in block and "arguments" in block:
new_block["name"] = block["name"]
new_block["arguments"] = block["arguments"] # type: ignore[typeddict-item]
else:
matching_tool_calls = [
call for call in tool_calls if call["id"] == block["id"]
] ]
new_content.append(new_block) if matching_tool_calls:
elif block_type == "tool_call": tool_call = matching_tool_calls[0]
new_block = {"type": "function_call", "call_id": block["id"]}
if "item_id" in block:
new_block["id"] = block["item_id"]
if "name" in block and "arguments" in block:
new_block["name"] = block["name"]
new_block["arguments"] = block["arguments"]
else:
tool_call = next(
call for call in message.tool_calls if call["id"] == block["id"]
)
if "name" not in block: if "name" not in block:
new_block["name"] = tool_call["name"] new_block["name"] = tool_call["name"]
if "arguments" not in block: if "arguments" not in block:
new_block["arguments"] = json.dumps(tool_call["args"]) new_block["arguments"] = json.dumps(tool_call["args"])
new_content.append(new_block) new_content.append(new_block)
elif ( elif (
is_data_content_block(block) is_data_content_block(cast(dict, block))
and block["type"] == "image" and block["type"] == "image"
and "base64" in block and "base64" in block
): and isinstance(block.get("id"), str)
new_block = {"type": "image_generation_call", "result": block["base64"]} and block["id"].startswith("ig_")
for extra_key in ("id", "status"): ):
if extra_key in block: new_block = {"type": "image_generation_call", "result": block["base64"]}
new_block[extra_key] = block[extra_key] for extra_key in ("id", "status"):
new_content.append(new_block) if extra_key in block:
elif block_type == "non_standard" and "value" in block: new_block[extra_key] = block[extra_key] # type: ignore[typeddict-item]
new_content.append(block["value"]) new_content.append(new_block)
else: elif block["type"] == "non_standard" and "value" in block:
new_content.append(block) new_content.append(block["value"])
else: else:
new_content.append(block) new_content.append(block)
@ -679,4 +690,4 @@ def _convert_from_v1_to_responses(message: AIMessage) -> AIMessage:
) )
) )
return message.model_copy(update={"content": new_content}) return new_content
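
A hedged illustration of the annotation normalization in this module: Responses-API url_citation and file_citation annotations both map onto a single standard "citation" block on the v1 side, with "filename" becoming "title" and "index" becoming "file_index" (values below are made up):

    responses_file_citation = {
        "type": "file_citation",
        "filename": "report.pdf",
        "file_id": "file_123",
        "index": 7,
    }
    # Roughly what _convert_annotation_to_v1 produces for the block above
    v1_citation = {
        "type": "citation",
        "title": "report.pdf",
        "file_id": "file_123",
        "file_index": 7,
    }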

File diff suppressed because it is too large.

View File

@ -56,6 +56,8 @@ langchain-tests = { path = "../../standard-tests", editable = true }
[tool.mypy] [tool.mypy]
disallow_untyped_defs = "True" disallow_untyped_defs = "True"
disable_error_code = ["typeddict-unknown-key"]
[[tool.mypy.overrides]] [[tool.mypy.overrides]]
module = "transformers" module = "transformers"
ignore_missing_imports = true ignore_missing_imports = true

View File

@ -14,16 +14,24 @@ from langchain_core.messages import (
HumanMessage, HumanMessage,
MessageLikeRepresentation, MessageLikeRepresentation,
) )
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
from langchain_core.messages.v1 import HumanMessage as HumanMessageV1
from pydantic import BaseModel from pydantic import BaseModel
from typing_extensions import TypedDict from typing_extensions import TypedDict
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI, ChatOpenAIV1
MODEL_NAME = "gpt-4o-mini" MODEL_NAME = "gpt-4o-mini"
def _check_response(response: Optional[BaseMessage]) -> None: def _check_response(response: Optional[BaseMessage], output_version) -> None:
assert isinstance(response, AIMessage) if output_version == "v1":
assert isinstance(response, AIMessageV1) or isinstance(
response, AIMessageChunkV1
)
else:
assert isinstance(response, AIMessage)
assert isinstance(response.content, list) assert isinstance(response.content, list)
for block in response.content: for block in response.content:
assert isinstance(block, dict) assert isinstance(block, dict)
@ -41,7 +49,10 @@ def _check_response(response: Optional[BaseMessage]) -> None:
for key in ["end_index", "start_index", "title", "type", "url"] for key in ["end_index", "start_index", "title", "type", "url"]
) )
text_content = response.text() if output_version == "v1":
text_content = response.text
else:
text_content = response.text()
assert isinstance(text_content, str) assert isinstance(text_content, str)
assert text_content assert text_content
assert response.usage_metadata assert response.usage_metadata
@ -56,22 +67,34 @@ def _check_response(response: Optional[BaseMessage]) -> None:
@pytest.mark.vcr
@pytest.mark.parametrize("output_version", ["responses/v1", "v1"])
def test_web_search(output_version: Literal["responses/v1", "v1"]) -> None:
    if output_version == "v1":
        llm = ChatOpenAIV1(model=MODEL_NAME)
    else:
        llm = ChatOpenAI(model=MODEL_NAME, output_version=output_version)
    first_response = llm.invoke(
        "What was a positive news story from today?",
        tools=[{"type": "web_search_preview"}],
    )
    _check_response(first_response, output_version)

    # Test streaming
    if isinstance(llm, ChatOpenAIV1):
        full: Optional[AIMessageChunkV1] = None
        for chunk in llm.stream(
            "What was a positive news story from today?",
            tools=[{"type": "web_search_preview"}],
        ):
            assert isinstance(chunk, AIMessageChunkV1)
            full = chunk if full is None else full + chunk
    else:
        full: Optional[BaseMessageChunk] = None
        for chunk in llm.stream(
            "What was a positive news story from today?",
            tools=[{"type": "web_search_preview"}],
        ):
            assert isinstance(chunk, AIMessageChunk)
            full = chunk if full is None else full + chunk
    _check_response(full, output_version)

    # Use OpenAI's stateful API
    response = llm.invoke(
@@ -79,38 +102,26 @@ def test_web_search(output_version: Literal["responses/v1", "v1"]) -> None:
        tools=[{"type": "web_search_preview"}],
        previous_response_id=first_response.response_metadata["id"],
    )
    _check_response(response, output_version)

    # Manually pass in chat history
    response = llm.invoke(
        [
            {"role": "user", "content": "What was a positive news story from today?"},
            first_response,
            {"role": "user", "content": "what about a negative one"},
        ],
        tools=[{"type": "web_search_preview"}],
    )
    _check_response(response, output_version)

    # Bind tool
    response = llm.bind_tools([{"type": "web_search_preview"}]).invoke(
        "What was a positive news story from today?"
    )
    _check_response(response, output_version)

    for msg in [first_response, full, response]:
        block_types = [block["type"] for block in msg.content]  # type: ignore[index]
        if output_version == "responses/v1":
            assert block_types == ["web_search_call", "text"]
@@ -125,7 +136,7 @@ async def test_web_search_async() -> None:
        "What was a positive news story from today?",
        tools=[{"type": "web_search_preview"}],
    )
    _check_response(response, "v0")
    assert response.response_metadata["status"]

    # Test streaming
@@ -137,7 +148,7 @@ async def test_web_search_async() -> None:
        assert isinstance(chunk, AIMessageChunk)
        full = chunk if full is None else full + chunk
    assert isinstance(full, AIMessageChunk)
    _check_response(full, "v0")

    for msg in [response, full]:
        assert msg.additional_kwargs["tool_outputs"]
@@ -148,8 +159,8 @@ async def test_web_search_async() -> None:
@pytest.mark.default_cassette("test_function_calling.yaml.gz")
@pytest.mark.vcr
@pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
def test_function_calling(output_version: Literal["v0", "responses/v1"]) -> None:
    def multiply(x: int, y: int) -> int:
        """return x * y"""
        return x * y

@@ -170,7 +181,33 @@ def test_function_calling(output_version: Literal["v0", "responses/v1", "v1"]) -> None:
    assert set(full.tool_calls[0]["args"]) == {"x", "y"}

    response = bound_llm.invoke("What was a positive news story from today?")
    _check_response(response, output_version)

@pytest.mark.default_cassette("test_function_calling.yaml.gz")
@pytest.mark.vcr
def test_function_calling_v1() -> None:
def multiply(x: int, y: int) -> int:
"""return x * y"""
return x * y
llm = ChatOpenAIV1(model=MODEL_NAME)
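    # Bind a user-defined function tool alongside the built-in web search tool.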
bound_llm = llm.bind_tools([multiply, {"type": "web_search_preview"}])
ai_msg = bound_llm.invoke("whats 5 * 4")
assert len(ai_msg.tool_calls) == 1
assert ai_msg.tool_calls[0]["name"] == "multiply"
assert set(ai_msg.tool_calls[0]["args"]) == {"x", "y"}
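    # Streamed chunks should aggregate into the same single multiply tool call.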
full: Any = None
for chunk in bound_llm.stream("whats 5 * 4"):
assert isinstance(chunk, AIMessageChunkV1)
full = chunk if full is None else full + chunk
assert len(full.tool_calls) == 1
assert full.tool_calls[0]["name"] == "multiply"
assert set(full.tool_calls[0]["args"]) == {"x", "y"}
response = bound_llm.invoke("What was a positive news story from today?")
_check_response(response, "v1")


class Foo(BaseModel):
@@ -183,10 +220,8 @@ class FooDict(TypedDict):
@pytest.mark.default_cassette("test_parsed_pydantic_schema.yaml.gz")
@pytest.mark.vcr
@pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
def test_parsed_pydantic_schema(output_version: Literal["v0", "responses/v1"]) -> None:
    llm = ChatOpenAI(
        model=MODEL_NAME, use_responses_api=True, output_version=output_version
    )
@@ -206,6 +241,28 @@ def test_parsed_pydantic_schema(
    assert parsed.response

@pytest.mark.default_cassette("test_parsed_pydantic_schema.yaml.gz")
@pytest.mark.vcr
def test_parsed_pydantic_schema_v1() -> None:
llm = ChatOpenAIV1(model=MODEL_NAME, use_responses_api=True)
response = llm.invoke("how are ya", response_format=Foo)
parsed = Foo(**json.loads(response.text))
assert parsed == response.parsed
assert parsed.response
# Test stream
full: Optional[AIMessageChunkV1] = None
chunks = []
for chunk in llm.stream("how are ya", response_format=Foo):
assert isinstance(chunk, AIMessageChunkV1)
full = chunk if full is None else full + chunk
chunks.append(chunk)
assert isinstance(full, AIMessageChunkV1)
parsed = Foo(**json.loads(full.text))
assert parsed == full.parsed
assert parsed.response


async def test_parsed_pydantic_schema_async() -> None:
    llm = ChatOpenAI(model=MODEL_NAME, use_responses_api=True)
    response = await llm.ainvoke("how are ya", response_format=Foo)
@@ -311,8 +368,8 @@ def test_function_calling_and_structured_output() -> None:
@pytest.mark.default_cassette("test_reasoning.yaml.gz")
@pytest.mark.vcr
@pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
def test_reasoning(output_version: Literal["v0", "responses/v1"]) -> None:
    llm = ChatOpenAI(
        model="o4-mini", use_responses_api=True, output_version=output_version
    )
@@ -337,6 +394,26 @@ def test_reasoning(output_version: Literal["v0", "responses/v1", "v1"]) -> None:
        assert block_types == ["reasoning", "text"]

@pytest.mark.default_cassette("test_reasoning.yaml.gz")
@pytest.mark.vcr
def test_reasoning_v1() -> None:
llm = ChatOpenAIV1(model="o4-mini", use_responses_api=True)
response = llm.invoke("Hello", reasoning={"effort": "low"})
assert isinstance(response, AIMessageV1)
# Test init params + streaming
llm = ChatOpenAIV1(model="o4-mini", reasoning={"effort": "low"})
full: Optional[AIMessageChunkV1] = None
for chunk in llm.stream("Hello"):
assert isinstance(chunk, AIMessageChunkV1)
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunkV1)
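    # Invoked and streamed results should each contain a reasoning block followed by text.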
for msg in [response, full]:
block_types = [block["type"] for block in msg.content]
assert block_types == ["reasoning", "text"]


def test_stateful_api() -> None:
    llm = ChatOpenAI(model=MODEL_NAME, use_responses_api=True)
    response = llm.invoke("how are you, my name is Bobo")
@@ -380,14 +457,14 @@ def test_file_search() -> None:
    input_message = {"role": "user", "content": "What is deep research by OpenAI?"}
    response = llm.invoke([input_message], tools=[tool])
    _check_response(response, "v0")

    full: Optional[BaseMessageChunk] = None
    for chunk in llm.stream([input_message], tools=[tool]):
        assert isinstance(chunk, AIMessageChunk)
        full = chunk if full is None else full + chunk
    assert isinstance(full, AIMessageChunk)
    _check_response(full, "v0")

    next_message = {"role": "user", "content": "Thank you."}
    _ = llm.invoke([input_message, full, next_message])
@@ -395,9 +472,9 @@ def test_file_search() -> None:
@pytest.mark.default_cassette("test_stream_reasoning_summary.yaml.gz")
@pytest.mark.vcr
@pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
def test_stream_reasoning_summary(
    output_version: Literal["v0", "responses/v1"],
) -> None:
    llm = ChatOpenAI(
        model="o4-mini",
@@ -424,7 +501,8 @@ def test_stream_reasoning_summary(
            assert isinstance(block["type"], str)
            assert isinstance(block["text"], str)
            assert block["text"]
    else:
        # output_version == "responses/v1"
        reasoning = next(
            block
            for block in response_1.content
@@ -438,18 +516,6 @@ def test_stream_reasoning_summary(
            assert isinstance(block["type"], str)
            assert isinstance(block["text"], str)
            assert block["text"]

    # Check we can pass back summaries
    message_2 = {"role": "user", "content": "Thank you."}
@@ -457,10 +523,45 @@ def test_stream_reasoning_summary(
    assert isinstance(response_2, AIMessage)

@pytest.mark.default_cassette("test_stream_reasoning_summary.yaml.gz")
@pytest.mark.vcr
def test_stream_reasoning_summary_v1() -> None:
llm = ChatOpenAIV1(
model="o4-mini",
# Routes to Responses API if `reasoning` is set.
reasoning={"effort": "medium", "summary": "auto"},
)
message_1 = {
"role": "user",
"content": "What was the third tallest buliding in the year 2000?",
}
response_1: Optional[AIMessageChunkV1] = None
for chunk in llm.stream([message_1]):
assert isinstance(chunk, AIMessageChunkV1)
response_1 = chunk if response_1 is None else response_1 + chunk
assert isinstance(response_1, AIMessageChunkV1)
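    # Each reasoning block carries an "rs_"-prefixed id, reasoning text, and an index.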
total_reasoning_blocks = 0
for block in response_1.content:
if block["type"] == "reasoning":
total_reasoning_blocks += 1
assert isinstance(block["id"], str) and block["id"].startswith("rs_")
assert isinstance(block["reasoning"], str)
assert isinstance(block["index"], int)
assert (
total_reasoning_blocks > 1
) # This query typically generates multiple reasoning blocks
# Check we can pass back summaries
message_2 = {"role": "user", "content": "Thank you."}
response_2 = llm.invoke([message_1, response_1, message_2])
assert isinstance(response_2, AIMessageV1)
@pytest.mark.default_cassette("test_code_interpreter.yaml.gz") @pytest.mark.default_cassette("test_code_interpreter.yaml.gz")
@pytest.mark.vcr @pytest.mark.vcr
@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"]) @pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
def test_code_interpreter(output_version: Literal["v0", "responses/v1", "v1"]) -> None: def test_code_interpreter(output_version: Literal["v0", "responses/v1"]) -> None:
llm = ChatOpenAI( llm = ChatOpenAI(
model="o4-mini", use_responses_api=True, output_version=output_version model="o4-mini", use_responses_api=True, output_version=output_version
) )
@ -473,33 +574,20 @@ def test_code_interpreter(output_version: Literal["v0", "responses/v1", "v1"]) -
} }
response = llm_with_tools.invoke([input_message]) response = llm_with_tools.invoke([input_message])
assert isinstance(response, AIMessage) assert isinstance(response, AIMessage)
_check_response(response) _check_response(response, output_version)
if output_version == "v0": if output_version == "v0":
tool_outputs = [ tool_outputs = [
item item
for item in response.additional_kwargs["tool_outputs"] for item in response.additional_kwargs["tool_outputs"]
if item["type"] == "code_interpreter_call" if item["type"] == "code_interpreter_call"
] ]
elif output_version == "responses/v1":
tool_outputs = [
item
for item in response.content
if isinstance(item, dict) and item["type"] == "code_interpreter_call"
]
else: else:
# v1 # responses/v1
tool_outputs = [ tool_outputs = [
item item
for item in response.content for item in response.content
if isinstance(item, dict) and item["type"] == "code_interpreter_call" if isinstance(item, dict) and item["type"] == "code_interpreter_call"
] ]
code_interpreter_result = next(
item
for item in response.content
if isinstance(item, dict) and item["type"] == "code_interpreter_result"
)
assert tool_outputs
assert code_interpreter_result
assert len(tool_outputs) == 1 assert len(tool_outputs) == 1
# Test streaming # Test streaming
@ -520,25 +608,65 @@ def test_code_interpreter(output_version: Literal["v0", "responses/v1", "v1"]) -
for item in response.additional_kwargs["tool_outputs"] for item in response.additional_kwargs["tool_outputs"]
if item["type"] == "code_interpreter_call" if item["type"] == "code_interpreter_call"
] ]
elif output_version == "responses/v1": else:
# responses/v1
tool_outputs = [ tool_outputs = [
item item
for item in response.content for item in response.content
if isinstance(item, dict) and item["type"] == "code_interpreter_call" if isinstance(item, dict) and item["type"] == "code_interpreter_call"
] ]
else: assert tool_outputs
code_interpreter_call = next(
item # Test we can pass back in
for item in response.content next_message = {"role": "user", "content": "Please add more comments to the code."}
if isinstance(item, dict) and item["type"] == "code_interpreter_call" _ = llm_with_tools.invoke([input_message, full, next_message])
)
code_interpreter_result = next(
item @pytest.mark.default_cassette("test_code_interpreter.yaml.gz")
for item in response.content @pytest.mark.vcr
if isinstance(item, dict) and item["type"] == "code_interpreter_result" def test_code_interpreter_v1() -> None:
) llm = ChatOpenAIV1(model="o4-mini", use_responses_api=True)
assert code_interpreter_call llm_with_tools = llm.bind_tools(
assert code_interpreter_result [{"type": "code_interpreter", "container": {"type": "auto"}}]
)
input_message = {
"role": "user",
"content": "Write and run code to answer the question: what is 3^3?",
}
response = llm_with_tools.invoke([input_message])
assert isinstance(response, AIMessageV1)
_check_response(response, "v1")
tool_outputs = [
item for item in response.content if item["type"] == "code_interpreter_call"
]
code_interpreter_result = next(
item for item in response.content if item["type"] == "code_interpreter_result"
)
assert tool_outputs
assert code_interpreter_result
assert len(tool_outputs) == 1
# Test streaming
# Use same container
container_id = tool_outputs[0]["container_id"]
llm_with_tools = llm.bind_tools(
[{"type": "code_interpreter", "container": container_id}]
)
full: Optional[AIMessageChunkV1] = None
for chunk in llm_with_tools.stream([input_message]):
assert isinstance(chunk, AIMessageChunkV1)
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunkV1)
code_interpreter_call = next(
item for item in full.content if item["type"] == "code_interpreter_call"
)
code_interpreter_result = next(
item for item in full.content if item["type"] == "code_interpreter_result"
)
assert code_interpreter_call
assert code_interpreter_result
    assert tool_outputs

    # Test we can pass back in
@@ -634,9 +762,59 @@ def test_mcp_builtin_zdr() -> None:
    _ = llm_with_tools.invoke([input_message, full, approval_message])

@pytest.mark.default_cassette("test_mcp_builtin_zdr.yaml.gz")
@pytest.mark.vcr
def test_mcp_builtin_zdr_v1() -> None:
llm = ChatOpenAIV1(
model="o4-mini", store=False, include=["reasoning.encrypted_content"]
)
llm_with_tools = llm.bind_tools(
[
{
"type": "mcp",
"server_label": "deepwiki",
"server_url": "https://mcp.deepwiki.com/mcp",
"require_approval": {"always": {"tool_names": ["read_wiki_structure"]}},
}
]
)
input_message = {
"role": "user",
"content": (
"What transport protocols does the 2025-03-26 version of the MCP spec "
"support?"
),
}
full: Optional[AIMessageChunkV1] = None
for chunk in llm_with_tools.stream([input_message]):
assert isinstance(chunk, AIMessageChunkV1)
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunkV1)
assert all(isinstance(block, dict) for block in full.content)
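    # Approve each MCP approval request, surfaced as a non_standard content block.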
approval_message = HumanMessageV1(
[
{
"type": "non_standard",
"value": {
"type": "mcp_approval_response",
"approve": True,
"approval_request_id": block["value"]["id"], # type: ignore[index]
},
}
for block in full.content
if block["type"] == "non_standard"
and block["value"]["type"] == "mcp_approval_request" # type: ignore[index]
]
)
_ = llm_with_tools.invoke([input_message, full, approval_message])
@pytest.mark.default_cassette("test_image_generation_streaming.yaml.gz") @pytest.mark.default_cassette("test_image_generation_streaming.yaml.gz")
@pytest.mark.vcr @pytest.mark.vcr
@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"]) @pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
def test_image_generation_streaming(output_version: str) -> None: def test_image_generation_streaming(output_version: str) -> None:
"""Test image generation streaming.""" """Test image generation streaming."""
llm = ChatOpenAI( llm = ChatOpenAI(
@ -710,9 +888,52 @@ def test_image_generation_streaming(output_version: str) -> None:
assert set(standard_keys).issubset(tool_output.keys()) assert set(standard_keys).issubset(tool_output.keys())
@pytest.mark.default_cassette("test_image_generation_streaming.yaml.gz")
@pytest.mark.vcr
def test_image_generation_streaming_v1() -> None:
"""Test image generation streaming."""
llm = ChatOpenAIV1(model="gpt-4.1", use_responses_api=True)
tool = {
"type": "image_generation",
"quality": "low",
"output_format": "jpeg",
"output_compression": 100,
"size": "1024x1024",
}
expected_keys = {
# Standard
"type",
"base64",
"mime_type",
"id",
"index",
# OpenAI-specific
"background",
"output_format",
"quality",
"revised_prompt",
"size",
"status",
}
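    # Aggregate the streamed chunks, then locate the generated image block in the content.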
full: Optional[AIMessageChunkV1] = None
for chunk in llm.stream("Draw a random short word in green font.", tools=[tool]):
assert isinstance(chunk, AIMessageChunkV1)
full = chunk if full is None else full + chunk
complete_ai_message = cast(AIMessageChunkV1, full)
tool_output = next(
block
for block in complete_ai_message.content
if isinstance(block, dict) and block["type"] == "image"
)
assert set(expected_keys).issubset(tool_output.keys())
@pytest.mark.default_cassette("test_image_generation_multi_turn.yaml.gz") @pytest.mark.default_cassette("test_image_generation_multi_turn.yaml.gz")
@pytest.mark.vcr @pytest.mark.vcr
@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"]) @pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
def test_image_generation_multi_turn(output_version: str) -> None: def test_image_generation_multi_turn(output_version: str) -> None:
"""Test multi-turn editing of image generation by passing in history.""" """Test multi-turn editing of image generation by passing in history."""
# Test multi-turn # Test multi-turn
@ -735,7 +956,7 @@ def test_image_generation_multi_turn(output_version: str) -> None:
] ]
ai_message = llm_with_tools.invoke(chat_history) ai_message = llm_with_tools.invoke(chat_history)
assert isinstance(ai_message, AIMessage) assert isinstance(ai_message, AIMessage)
_check_response(ai_message) _check_response(ai_message, output_version)
expected_keys = { expected_keys = {
"id", "id",
@ -801,7 +1022,7 @@ def test_image_generation_multi_turn(output_version: str) -> None:
ai_message2 = llm_with_tools.invoke(chat_history) ai_message2 = llm_with_tools.invoke(chat_history)
assert isinstance(ai_message2, AIMessage) assert isinstance(ai_message2, AIMessage)
_check_response(ai_message2) _check_response(ai_message2, output_version)
if output_version == "v0": if output_version == "v0":
tool_output = ai_message2.additional_kwargs["tool_outputs"][0] tool_output = ai_message2.additional_kwargs["tool_outputs"][0]
@ -821,3 +1042,76 @@ def test_image_generation_multi_turn(output_version: str) -> None:
if isinstance(block, dict) and block["type"] == "image" if isinstance(block, dict) and block["type"] == "image"
) )
assert set(standard_keys).issubset(tool_output.keys()) assert set(standard_keys).issubset(tool_output.keys())
@pytest.mark.default_cassette("test_image_generation_multi_turn.yaml.gz")
@pytest.mark.vcr
def test_image_generation_multi_turn_v1() -> None:
"""Test multi-turn editing of image generation by passing in history."""
# Test multi-turn
llm = ChatOpenAIV1(model="gpt-4.1", use_responses_api=True)
# Test invocation
tool = {
"type": "image_generation",
"quality": "low",
"output_format": "jpeg",
"output_compression": 100,
"size": "1024x1024",
}
llm_with_tools = llm.bind_tools([tool])
chat_history: list[MessageLikeRepresentation] = [
{"role": "user", "content": "Draw a random short word in green font."}
]
ai_message = llm_with_tools.invoke(chat_history)
assert isinstance(ai_message, AIMessageV1)
_check_response(ai_message, "v1")
expected_keys = {
# Standard
"type",
"base64",
"mime_type",
"id",
# OpenAI-specific
"background",
"output_format",
"quality",
"revised_prompt",
"size",
"status",
}
standard_keys = {"type", "base64", "id", "status"}
tool_output = next(
block
for block in ai_message.content
if isinstance(block, dict) and block["type"] == "image"
)
assert set(standard_keys).issubset(tool_output.keys())
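    # Pass the image output back as chat history and request an edit.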
chat_history.extend(
[
# AI message with tool output
ai_message,
# New request
{
"role": "user",
"content": (
"Now, change the font to blue. Keep the word and everything else "
"the same."
),
},
]
)
ai_message2 = llm_with_tools.invoke(chat_history)
assert isinstance(ai_message2, AIMessageV1)
_check_response(ai_message2, "v1")
tool_output = next(
block
for block in ai_message2.content
if isinstance(block, dict) and block["type"] == "image"
)
assert set(expected_keys).issubset(tool_output.keys())