mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-01 09:04:03 +00:00
feat(openai): v1 message format support (#32296)
This commit is contained in:
parent
7166adce1f
commit
c15e55b33c
@ -2,8 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from abc import ABC
|
||||
from collections.abc import Mapping, Sequence
|
||||
from functools import cache
|
||||
from typing import (
|
||||
@ -26,7 +25,6 @@ from langchain_core.messages import (
|
||||
AnyMessage,
|
||||
BaseMessage,
|
||||
MessageLikeRepresentation,
|
||||
get_buffer_string,
|
||||
)
|
||||
from langchain_core.messages.v1 import AIMessage as AIMessageV1
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
@ -166,7 +164,6 @@ class BaseLanguageModel(
|
||||
list[AnyMessage],
|
||||
]
|
||||
|
||||
@abstractmethod
|
||||
def generate_prompt(
|
||||
self,
|
||||
prompts: list[PromptValue],
|
||||
@ -201,7 +198,6 @@ class BaseLanguageModel(
|
||||
prompt and additional model provider-specific output.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def agenerate_prompt(
|
||||
self,
|
||||
prompts: list[PromptValue],
|
||||
@ -245,7 +241,6 @@ class BaseLanguageModel(
|
||||
raise NotImplementedError
|
||||
|
||||
@deprecated("0.1.7", alternative="invoke", removal="1.0")
|
||||
@abstractmethod
|
||||
def predict(
|
||||
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
|
||||
) -> str:
|
||||
@ -266,7 +261,6 @@ class BaseLanguageModel(
|
||||
"""
|
||||
|
||||
@deprecated("0.1.7", alternative="invoke", removal="1.0")
|
||||
@abstractmethod
|
||||
def predict_messages(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
@ -291,7 +285,6 @@ class BaseLanguageModel(
|
||||
"""
|
||||
|
||||
@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
|
||||
@abstractmethod
|
||||
async def apredict(
|
||||
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
|
||||
) -> str:
|
||||
@ -312,7 +305,6 @@ class BaseLanguageModel(
|
||||
"""
|
||||
|
||||
@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
|
||||
@abstractmethod
|
||||
async def apredict_messages(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
@ -368,33 +360,6 @@ class BaseLanguageModel(
|
||||
"""
|
||||
return len(self.get_token_ids(text))
|
||||
|
||||
def get_num_tokens_from_messages(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
tools: Optional[Sequence] = None,
|
||||
) -> int:
|
||||
"""Get the number of tokens in the messages.
|
||||
|
||||
Useful for checking if an input fits in a model's context window.
|
||||
|
||||
**Note**: the base implementation of get_num_tokens_from_messages ignores
|
||||
tool schemas.
|
||||
|
||||
Args:
|
||||
messages: The message inputs to tokenize.
|
||||
tools: If provided, sequence of dict, BaseModel, function, or BaseTools
|
||||
to be converted to tool schemas.
|
||||
|
||||
Returns:
|
||||
The sum of the number of tokens across the messages.
|
||||
"""
|
||||
if tools is not None:
|
||||
warnings.warn(
|
||||
"Counting tokens in tool schemas is not yet supported. Ignoring tools.",
|
||||
stacklevel=2,
|
||||
)
|
||||
return sum(self.get_num_tokens(get_buffer_string([m])) for m in messages)
|
||||
|
||||
@classmethod
|
||||
def _all_required_field_names(cls) -> set:
|
||||
"""DEPRECATED: Kept for backwards compatibility.
|
||||
|
@ -55,12 +55,11 @@ from langchain_core.messages import (
|
||||
HumanMessage,
|
||||
convert_to_messages,
|
||||
convert_to_openai_image_block,
|
||||
get_buffer_string,
|
||||
is_data_content_block,
|
||||
message_chunk_to_message,
|
||||
)
|
||||
from langchain_core.messages import content_blocks as types
|
||||
from langchain_core.messages.ai import _LC_ID_PREFIX
|
||||
from langchain_core.messages.v1 import AIMessage as AIMessageV1
|
||||
from langchain_core.outputs import (
|
||||
ChatGeneration,
|
||||
ChatGenerationChunk,
|
||||
@ -222,23 +221,6 @@ def _format_ls_structured_output(ls_structured_output_format: Optional[dict]) ->
|
||||
return ls_structured_output_format_dict
|
||||
|
||||
|
||||
def _convert_to_v1(message: AIMessage) -> AIMessageV1:
|
||||
"""Best-effort conversion of a V0 AIMessage to V1."""
|
||||
if isinstance(message.content, str):
|
||||
content: list[types.ContentBlock] = []
|
||||
if message.content:
|
||||
content = [{"type": "text", "text": message.content}]
|
||||
|
||||
for tool_call in message.tool_calls:
|
||||
content.append(tool_call)
|
||||
|
||||
return AIMessageV1(
|
||||
content=content,
|
||||
usage_metadata=message.usage_metadata,
|
||||
response_metadata=message.response_metadata,
|
||||
)
|
||||
|
||||
|
||||
class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
"""Base class for chat models.
|
||||
|
||||
@ -1370,6 +1352,33 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
starter_dict["_type"] = self._llm_type
|
||||
return starter_dict
|
||||
|
||||
def get_num_tokens_from_messages(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
tools: Optional[Sequence] = None,
|
||||
) -> int:
|
||||
"""Get the number of tokens in the messages.
|
||||
|
||||
Useful for checking if an input fits in a model's context window.
|
||||
|
||||
**Note**: the base implementation of get_num_tokens_from_messages ignores
|
||||
tool schemas.
|
||||
|
||||
Args:
|
||||
messages: The message inputs to tokenize.
|
||||
tools: If provided, sequence of dict, BaseModel, function, or BaseTools
|
||||
to be converted to tool schemas.
|
||||
|
||||
Returns:
|
||||
The sum of the number of tokens across the messages.
|
||||
"""
|
||||
if tools is not None:
|
||||
warnings.warn(
|
||||
"Counting tokens in tool schemas is not yet supported. Ignoring tools.",
|
||||
stacklevel=2,
|
||||
)
|
||||
return sum(self.get_num_tokens(get_buffer_string([m])) for m in messages)
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[
|
||||
|
@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import typing
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncIterator, Iterator, Sequence
|
||||
from operator import itemgetter
|
||||
@ -38,11 +39,14 @@ from langchain_core.language_models.base import (
|
||||
)
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
convert_to_openai_image_block,
|
||||
get_buffer_string,
|
||||
is_data_content_block,
|
||||
)
|
||||
from langchain_core.messages.utils import convert_to_messages_v1
|
||||
from langchain_core.messages.utils import (
|
||||
_convert_from_v1_message,
|
||||
convert_to_messages_v1,
|
||||
)
|
||||
from langchain_core.messages.v1 import AIMessage as AIMessageV1
|
||||
from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
|
||||
from langchain_core.messages.v1 import HumanMessage as HumanMessageV1
|
||||
@ -735,7 +739,7 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
|
||||
*,
|
||||
tool_choice: Optional[Union[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
) -> Runnable[LanguageModelInput, AIMessageV1]:
|
||||
"""Bind tools to the model.
|
||||
|
||||
Args:
|
||||
@ -899,6 +903,34 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
|
||||
return RunnableMap(raw=llm) | parser_with_fallback
|
||||
return llm | output_parser
|
||||
|
||||
def get_num_tokens_from_messages(
|
||||
self,
|
||||
messages: list[MessageV1],
|
||||
tools: Optional[Sequence] = None,
|
||||
) -> int:
|
||||
"""Get the number of tokens in the messages.
|
||||
|
||||
Useful for checking if an input fits in a model's context window.
|
||||
|
||||
**Note**: the base implementation of get_num_tokens_from_messages ignores
|
||||
tool schemas.
|
||||
|
||||
Args:
|
||||
messages: The message inputs to tokenize.
|
||||
tools: If provided, sequence of dict, BaseModel, function, or BaseTools
|
||||
to be converted to tool schemas.
|
||||
|
||||
Returns:
|
||||
The sum of the number of tokens across the messages.
|
||||
"""
|
||||
messages_v0 = [_convert_from_v1_message(message) for message in messages]
|
||||
if tools is not None:
|
||||
warnings.warn(
|
||||
"Counting tokens in tool schemas is not yet supported. Ignoring tools.",
|
||||
stacklevel=2,
|
||||
)
|
||||
return sum(self.get_num_tokens(get_buffer_string([m])) for m in messages_v0)
|
||||
|
||||
|
||||
def _gen_info_and_msg_metadata(
|
||||
generation: Union[ChatGeneration, ChatGenerationChunk],
|
||||
|
@ -706,6 +706,7 @@ ToolContentBlock = Union[
|
||||
ContentBlock = Union[
|
||||
TextContentBlock,
|
||||
ToolCall,
|
||||
InvalidToolCall,
|
||||
ReasoningContentBlock,
|
||||
NonStandardContentBlock,
|
||||
DataContentBlock,
|
||||
|
@ -384,38 +384,37 @@ def _convert_from_v1_message(message: MessageV1) -> BaseMessage:
|
||||
Returns:
|
||||
BaseMessage: Converted message instance.
|
||||
"""
|
||||
# type ignores here are because AIMessageV1.content is a list of dicts.
|
||||
# AIMessageV0.content expects str or list[str | dict].
|
||||
content = cast("Union[str, list[str | dict]]", message.content)
|
||||
if isinstance(message, AIMessageV1):
|
||||
return AIMessage(
|
||||
content=message.content, # type: ignore[arg-type]
|
||||
content=content,
|
||||
id=message.id,
|
||||
name=message.name,
|
||||
tool_calls=message.tool_calls,
|
||||
response_metadata=message.response_metadata,
|
||||
response_metadata=cast("dict", message.response_metadata),
|
||||
)
|
||||
if isinstance(message, AIMessageChunkV1):
|
||||
return AIMessageChunk(
|
||||
content=message.content, # type: ignore[arg-type]
|
||||
content=content,
|
||||
id=message.id,
|
||||
name=message.name,
|
||||
tool_call_chunks=message.tool_call_chunks,
|
||||
response_metadata=message.response_metadata,
|
||||
response_metadata=cast("dict", message.response_metadata),
|
||||
)
|
||||
if isinstance(message, HumanMessageV1):
|
||||
return HumanMessage(
|
||||
content=message.content, # type: ignore[arg-type]
|
||||
content=content,
|
||||
id=message.id,
|
||||
name=message.name,
|
||||
)
|
||||
if isinstance(message, SystemMessageV1):
|
||||
return SystemMessage(
|
||||
content=message.content, # type: ignore[arg-type]
|
||||
content=content,
|
||||
id=message.id,
|
||||
)
|
||||
if isinstance(message, ToolMessageV1):
|
||||
return ToolMessage(
|
||||
content=message.content, # type: ignore[arg-type]
|
||||
content=content,
|
||||
id=message.id,
|
||||
)
|
||||
message = f"Unsupported message type: {type(message)}"
|
||||
@ -501,7 +500,10 @@ def _convert_to_message_v1(message: MessageLikeRepresentation) -> MessageV1:
|
||||
ValueError: if the message dict does not contain the required keys.
|
||||
"""
|
||||
if isinstance(message, MessageV1Types):
|
||||
message_ = message
|
||||
if isinstance(message, AIMessageChunkV1):
|
||||
message_ = message.to_message()
|
||||
else:
|
||||
message_ = message
|
||||
elif isinstance(message, str):
|
||||
message_ = _create_message_from_message_type_v1("human", message)
|
||||
elif isinstance(message, Sequence) and len(message) == 2:
|
||||
|
@ -5,6 +5,8 @@ import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal, Optional, TypedDict, Union, cast, get_args
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
import langchain_core.messages.content_blocks as types
|
||||
from langchain_core.messages.ai import _LC_ID_PREFIX, UsageMetadata, add_usage
|
||||
from langchain_core.messages.base import merge_content
|
||||
@ -32,20 +34,20 @@ def _ensure_id(id_val: Optional[str]) -> str:
|
||||
return id_val or str(uuid.uuid4())
|
||||
|
||||
|
||||
class Provider(TypedDict):
|
||||
"""Information about the provider that generated the message.
|
||||
class ResponseMetadata(TypedDict, total=False):
|
||||
"""Metadata about the response from the AI provider.
|
||||
|
||||
Contains metadata about the AI provider and model used to generate content.
|
||||
Contains additional information returned by the provider, such as
|
||||
response headers, service tiers, log probabilities, system fingerprints, etc.
|
||||
|
||||
Attributes:
|
||||
name: Name and version of the provider that created the content block.
|
||||
model_name: Name of the model that generated the content block.
|
||||
Extra keys are permitted from what is typed here.
|
||||
"""
|
||||
|
||||
name: str
|
||||
"""Name and version of the provider that created the content block."""
|
||||
model_provider: str
|
||||
"""Name and version of the provider that created the message (e.g., openai)."""
|
||||
|
||||
model_name: str
|
||||
"""Name of the model that generated the content block."""
|
||||
"""Name of the model that generated the message."""
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -91,21 +93,29 @@ class AIMessage:
|
||||
usage_metadata: Optional[UsageMetadata] = None
|
||||
"""If provided, usage metadata for a message, such as token counts."""
|
||||
|
||||
response_metadata: dict = field(default_factory=dict)
|
||||
response_metadata: ResponseMetadata = field(
|
||||
default_factory=lambda: ResponseMetadata()
|
||||
)
|
||||
"""Metadata about the response.
|
||||
|
||||
This field should include non-standard data returned by the provider, such as
|
||||
response headers, service tiers, or log probabilities.
|
||||
"""
|
||||
|
||||
parsed: Optional[Union[dict[str, Any], BaseModel]] = None
|
||||
"""Auto-parsed message contents, if applicable."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
content: Union[str, list[types.ContentBlock]],
|
||||
id: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
lc_version: str = "v1",
|
||||
response_metadata: Optional[dict] = None,
|
||||
response_metadata: Optional[ResponseMetadata] = None,
|
||||
usage_metadata: Optional[UsageMetadata] = None,
|
||||
tool_calls: Optional[list[types.ToolCall]] = None,
|
||||
invalid_tool_calls: Optional[list[types.InvalidToolCall]] = None,
|
||||
parsed: Optional[Union[dict[str, Any], BaseModel]] = None,
|
||||
):
|
||||
"""Initialize an AI message.
|
||||
|
||||
@ -116,6 +126,11 @@ class AIMessage:
|
||||
lc_version: Encoding version for the message.
|
||||
response_metadata: Optional metadata about the response.
|
||||
usage_metadata: Optional metadata about token usage.
|
||||
tool_calls: Optional list of tool calls made by the AI. Tool calls should
|
||||
generally be included in message content. If passed on init, they will
|
||||
be added to the content list.
|
||||
invalid_tool_calls: Optional list of tool calls that failed validation.
|
||||
parsed: Optional auto-parsed message contents, if applicable.
|
||||
"""
|
||||
if isinstance(content, str):
|
||||
self.content = [{"type": "text", "text": content}]
|
||||
@ -126,13 +141,27 @@ class AIMessage:
|
||||
self.name = name
|
||||
self.lc_version = lc_version
|
||||
self.usage_metadata = usage_metadata
|
||||
self.parsed = parsed
|
||||
if response_metadata is None:
|
||||
self.response_metadata = {}
|
||||
else:
|
||||
self.response_metadata = response_metadata
|
||||
|
||||
self._tool_calls: list[types.ToolCall] = []
|
||||
self._invalid_tool_calls: list[types.InvalidToolCall] = []
|
||||
# Add tool calls to content if provided on init
|
||||
if tool_calls:
|
||||
content_tool_calls = {
|
||||
block["id"]
|
||||
for block in self.content
|
||||
if block["type"] == "tool_call" and "id" in block
|
||||
}
|
||||
for tool_call in tool_calls:
|
||||
if "id" in tool_call and tool_call["id"] in content_tool_calls:
|
||||
continue
|
||||
self.content.append(tool_call)
|
||||
self._tool_calls = [
|
||||
block for block in self.content if block["type"] == "tool_call"
|
||||
]
|
||||
self.invalid_tool_calls = invalid_tool_calls or []
|
||||
|
||||
@property
|
||||
def text(self) -> Optional[str]:
|
||||
@ -150,7 +179,7 @@ class AIMessage:
|
||||
tool_calls = [block for block in self.content if block["type"] == "tool_call"]
|
||||
if tool_calls:
|
||||
self._tool_calls = tool_calls
|
||||
return self._tool_calls
|
||||
return [block for block in self.content if block["type"] == "tool_call"]
|
||||
|
||||
@tool_calls.setter
|
||||
def tool_calls(self, value: list[types.ToolCall]) -> None:
|
||||
@ -202,13 +231,16 @@ class AIMessageChunk:
|
||||
These data represent incremental usage statistics, as opposed to a running total.
|
||||
"""
|
||||
|
||||
response_metadata: dict = field(init=False)
|
||||
response_metadata: ResponseMetadata = field(init=False)
|
||||
"""Metadata about the response chunk.
|
||||
|
||||
This field should include non-standard data returned by the provider, such as
|
||||
response headers, service tiers, or log probabilities.
|
||||
"""
|
||||
|
||||
parsed: Optional[Union[dict[str, Any], BaseModel]] = None
|
||||
"""Auto-parsed message contents, if applicable."""
|
||||
|
||||
tool_call_chunks: list[types.ToolCallChunk] = field(init=False)
|
||||
|
||||
def __init__(
|
||||
@ -217,9 +249,10 @@ class AIMessageChunk:
|
||||
id: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
lc_version: str = "v1",
|
||||
response_metadata: Optional[dict] = None,
|
||||
response_metadata: Optional[ResponseMetadata] = None,
|
||||
usage_metadata: Optional[UsageMetadata] = None,
|
||||
tool_call_chunks: Optional[list[types.ToolCallChunk]] = None,
|
||||
parsed: Optional[Union[dict[str, Any], BaseModel]] = None,
|
||||
):
|
||||
"""Initialize an AI message.
|
||||
|
||||
@ -231,6 +264,7 @@ class AIMessageChunk:
|
||||
response_metadata: Optional metadata about the response.
|
||||
usage_metadata: Optional metadata about token usage.
|
||||
tool_call_chunks: Optional list of partial tool call data.
|
||||
parsed: Optional auto-parsed message contents, if applicable.
|
||||
"""
|
||||
if isinstance(content, str):
|
||||
self.content = [{"type": "text", "text": content, "index": 0}]
|
||||
@ -241,6 +275,7 @@ class AIMessageChunk:
|
||||
self.name = name
|
||||
self.lc_version = lc_version
|
||||
self.usage_metadata = usage_metadata
|
||||
self.parsed = parsed
|
||||
if response_metadata is None:
|
||||
self.response_metadata = {}
|
||||
else:
|
||||
@ -251,7 +286,7 @@ class AIMessageChunk:
|
||||
self.tool_call_chunks = tool_call_chunks
|
||||
|
||||
self._tool_calls: list[types.ToolCall] = []
|
||||
self._invalid_tool_calls: list[types.InvalidToolCall] = []
|
||||
self.invalid_tool_calls: list[types.InvalidToolCall] = []
|
||||
self._init_tool_calls()
|
||||
|
||||
def _init_tool_calls(self) -> None:
|
||||
@ -264,7 +299,7 @@ class AIMessageChunk:
|
||||
ValueError: If the tool call chunks are malformed.
|
||||
"""
|
||||
self._tool_calls = []
|
||||
self._invalid_tool_calls = []
|
||||
self.invalid_tool_calls = []
|
||||
if not self.tool_call_chunks:
|
||||
if self._tool_calls:
|
||||
self.tool_call_chunks = [
|
||||
@ -276,14 +311,14 @@ class AIMessageChunk:
|
||||
)
|
||||
for tc in self._tool_calls
|
||||
]
|
||||
if self._invalid_tool_calls:
|
||||
if self.invalid_tool_calls:
|
||||
tool_call_chunks = self.tool_call_chunks
|
||||
tool_call_chunks.extend(
|
||||
[
|
||||
create_tool_call_chunk(
|
||||
name=tc["name"], args=tc["args"], id=tc["id"], index=None
|
||||
)
|
||||
for tc in self._invalid_tool_calls
|
||||
for tc in self.invalid_tool_calls
|
||||
]
|
||||
)
|
||||
self.tool_call_chunks = tool_call_chunks
|
||||
@ -294,9 +329,9 @@ class AIMessageChunk:
|
||||
def add_chunk_to_invalid_tool_calls(chunk: ToolCallChunk) -> None:
|
||||
invalid_tool_calls.append(
|
||||
create_invalid_tool_call(
|
||||
name=chunk["name"],
|
||||
args=chunk["args"],
|
||||
id=chunk["id"],
|
||||
name=chunk.get("name", ""),
|
||||
args=chunk.get("args", ""),
|
||||
id=chunk.get("id", ""),
|
||||
error=None,
|
||||
)
|
||||
)
|
||||
@ -307,9 +342,9 @@ class AIMessageChunk:
|
||||
if isinstance(args_, dict):
|
||||
tool_calls.append(
|
||||
create_tool_call(
|
||||
name=chunk["name"] or "",
|
||||
name=chunk.get("name", ""),
|
||||
args=args_,
|
||||
id=chunk["id"],
|
||||
id=chunk.get("id", ""),
|
||||
)
|
||||
)
|
||||
else:
|
||||
@ -317,7 +352,7 @@ class AIMessageChunk:
|
||||
except Exception:
|
||||
add_chunk_to_invalid_tool_calls(chunk)
|
||||
self._tool_calls = tool_calls
|
||||
self._invalid_tool_calls = invalid_tool_calls
|
||||
self.invalid_tool_calls = invalid_tool_calls
|
||||
|
||||
@property
|
||||
def text(self) -> Optional[str]:
|
||||
@ -361,6 +396,20 @@ class AIMessageChunk:
|
||||
error_msg = "Can only add AIMessageChunk or sequence of AIMessageChunk."
|
||||
raise NotImplementedError(error_msg)
|
||||
|
||||
def to_message(self) -> "AIMessage":
|
||||
"""Convert this AIMessageChunk to an AIMessage."""
|
||||
return AIMessage(
|
||||
content=self.content,
|
||||
id=self.id,
|
||||
name=self.name,
|
||||
lc_version=self.lc_version,
|
||||
response_metadata=self.response_metadata,
|
||||
usage_metadata=self.usage_metadata,
|
||||
tool_calls=self.tool_calls,
|
||||
invalid_tool_calls=self.invalid_tool_calls,
|
||||
parsed=self.parsed,
|
||||
)
|
||||
|
||||
|
||||
def add_ai_message_chunks(
|
||||
left: AIMessageChunk, *others: AIMessageChunk
|
||||
@ -373,7 +422,8 @@ def add_ai_message_chunks(
|
||||
*(cast("list[str | dict[Any, Any]]", o.content) for o in others),
|
||||
)
|
||||
response_metadata = merge_dicts(
|
||||
left.response_metadata, *(o.response_metadata for o in others)
|
||||
cast("dict", left.response_metadata),
|
||||
*(cast("dict", o.response_metadata) for o in others),
|
||||
)
|
||||
|
||||
# Merge tool call chunks
|
||||
@ -400,6 +450,15 @@ def add_ai_message_chunks(
|
||||
else:
|
||||
usage_metadata = None
|
||||
|
||||
# Parsed
|
||||
# 'parsed' always represents an aggregation not an incremental value, so the last
|
||||
# non-null value is kept.
|
||||
parsed = None
|
||||
for m in reversed([left, *others]):
|
||||
if m.parsed is not None:
|
||||
parsed = m.parsed
|
||||
break
|
||||
|
||||
chunk_id = None
|
||||
candidates = [left.id] + [o.id for o in others]
|
||||
# first pass: pick the first non-run-* id
|
||||
@ -417,8 +476,9 @@ def add_ai_message_chunks(
|
||||
return left.__class__(
|
||||
content=cast("list[types.ContentBlock]", content),
|
||||
tool_call_chunks=tool_call_chunks,
|
||||
response_metadata=response_metadata,
|
||||
response_metadata=cast("ResponseMetadata", response_metadata),
|
||||
usage_metadata=usage_metadata,
|
||||
parsed=parsed,
|
||||
id=chunk_id,
|
||||
)
|
||||
|
||||
@ -455,19 +515,25 @@ class HumanMessage:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, content: Union[str, list[types.ContentBlock]], id: Optional[str] = None
|
||||
self,
|
||||
content: Union[str, list[types.ContentBlock]],
|
||||
*,
|
||||
id: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
"""Initialize a human message.
|
||||
|
||||
Args:
|
||||
content: Message content as string or list of content blocks.
|
||||
id: Optional unique identifier for the message.
|
||||
name: Optional human-readable name for the message.
|
||||
"""
|
||||
self.id = _ensure_id(id)
|
||||
if isinstance(content, str):
|
||||
self.content = [{"type": "text", "text": content}]
|
||||
else:
|
||||
self.content = content
|
||||
self.name = name
|
||||
|
||||
def text(self) -> str:
|
||||
"""Extract all text content from the message.
|
||||
@ -497,20 +563,47 @@ class SystemMessage:
|
||||
content: list[types.ContentBlock]
|
||||
type: Literal["system"] = "system"
|
||||
|
||||
name: Optional[str] = None
|
||||
"""An optional name for the message.
|
||||
|
||||
This can be used to provide a human-readable name for the message.
|
||||
|
||||
Usage of this field is optional, and whether it's used or not is up to the
|
||||
model implementation.
|
||||
"""
|
||||
|
||||
custom_role: Optional[str] = None
|
||||
"""If provided, a custom role for the system message.
|
||||
|
||||
Example: ``"developer"``.
|
||||
|
||||
Integration packages may use this field to assign the system message role if it
|
||||
contains a recognized value.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, content: Union[str, list[types.ContentBlock]], *, id: Optional[str] = None
|
||||
self,
|
||||
content: Union[str, list[types.ContentBlock]],
|
||||
*,
|
||||
id: Optional[str] = None,
|
||||
custom_role: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
"""Initialize a system message.
|
||||
"""Initialize a human message.
|
||||
|
||||
Args:
|
||||
content: System instructions as string or list of content blocks.
|
||||
content: Message content as string or list of content blocks.
|
||||
id: Optional unique identifier for the message.
|
||||
custom_role: If provided, a custom role for the system message.
|
||||
name: Optional human-readable name for the message.
|
||||
"""
|
||||
self.id = _ensure_id(id)
|
||||
if isinstance(content, str):
|
||||
self.content = [{"type": "text", "text": content}]
|
||||
else:
|
||||
self.content = content
|
||||
self.custom_role = custom_role
|
||||
self.name = name
|
||||
|
||||
def text(self) -> str:
|
||||
"""Extract all text content from the system message."""
|
||||
@ -537,11 +630,51 @@ class ToolMessage:
|
||||
|
||||
id: str
|
||||
tool_call_id: str
|
||||
content: list[dict[str, Any]]
|
||||
content: list[types.ContentBlock]
|
||||
artifact: Optional[Any] = None # App-side payload not for the model
|
||||
|
||||
name: Optional[str] = None
|
||||
"""An optional name for the message.
|
||||
|
||||
This can be used to provide a human-readable name for the message.
|
||||
|
||||
Usage of this field is optional, and whether it's used or not is up to the
|
||||
model implementation.
|
||||
"""
|
||||
|
||||
status: Literal["success", "error"] = "success"
|
||||
type: Literal["tool"] = "tool"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
content: Union[str, list[types.ContentBlock]],
|
||||
tool_call_id: str,
|
||||
*,
|
||||
id: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
artifact: Optional[Any] = None,
|
||||
status: Literal["success", "error"] = "success",
|
||||
):
|
||||
"""Initialize a human message.
|
||||
|
||||
Args:
|
||||
content: Message content as string or list of content blocks.
|
||||
tool_call_id: ID of the tool call this message responds to.
|
||||
id: Optional unique identifier for the message.
|
||||
name: Optional human-readable name for the message.
|
||||
artifact: Optional app-side payload not intended for the model.
|
||||
status: Execution status ("success" or "error").
|
||||
"""
|
||||
self.id = _ensure_id(id)
|
||||
self.tool_call_id = tool_call_id
|
||||
if isinstance(content, str):
|
||||
self.content = [{"type": "text", "text": content}]
|
||||
else:
|
||||
self.content = content
|
||||
self.name = name
|
||||
self.artifact = artifact
|
||||
self.status = status
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
"""Extract all text content from the tool message."""
|
||||
|
@ -9,7 +9,7 @@ from typing import Annotated, Any, Optional
|
||||
from pydantic import SkipValidation, ValidationError
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.messages import AIMessage, InvalidToolCall
|
||||
from langchain_core.messages import AIMessage, InvalidToolCall, ToolCall
|
||||
from langchain_core.messages.tool import invalid_tool_call
|
||||
from langchain_core.messages.tool import tool_call as create_tool_call
|
||||
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
|
||||
@ -26,7 +26,7 @@ def parse_tool_call(
|
||||
partial: bool = False,
|
||||
strict: bool = False,
|
||||
return_id: bool = True,
|
||||
) -> Optional[dict[str, Any]]:
|
||||
) -> Optional[ToolCall]:
|
||||
"""Parse a single tool call.
|
||||
|
||||
Args:
|
||||
|
@ -8,17 +8,65 @@ from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal, cast
|
||||
from typing import Literal, Union, cast
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
from typing_extensions import TypedDict, overload
|
||||
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AnyMessage,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
get_buffer_string,
|
||||
)
|
||||
from langchain_core.messages import content_blocks as types
|
||||
from langchain_core.messages.v1 import AIMessage as AIMessageV1
|
||||
from langchain_core.messages.v1 import HumanMessage as HumanMessageV1
|
||||
from langchain_core.messages.v1 import MessageV1, ResponseMetadata
|
||||
from langchain_core.messages.v1 import SystemMessage as SystemMessageV1
|
||||
from langchain_core.messages.v1 import ToolMessage as ToolMessageV1
|
||||
|
||||
|
||||
def _convert_to_v1(message: BaseMessage) -> MessageV1:
|
||||
"""Best-effort conversion of a V0 AIMessage to V1."""
|
||||
if isinstance(message.content, str):
|
||||
content: list[types.ContentBlock] = []
|
||||
if message.content:
|
||||
content = [{"type": "text", "text": message.content}]
|
||||
else:
|
||||
content = []
|
||||
for block in message.content:
|
||||
if isinstance(block, str):
|
||||
content.append({"type": "text", "text": block})
|
||||
elif isinstance(block, dict):
|
||||
content.append(cast("types.ContentBlock", block))
|
||||
else:
|
||||
pass
|
||||
|
||||
if isinstance(message, HumanMessage):
|
||||
return HumanMessageV1(content=content)
|
||||
if isinstance(message, AIMessage):
|
||||
for tool_call in message.tool_calls:
|
||||
content.append(tool_call)
|
||||
return AIMessageV1(
|
||||
content=content,
|
||||
usage_metadata=message.usage_metadata,
|
||||
response_metadata=cast("ResponseMetadata", message.response_metadata),
|
||||
tool_calls=message.tool_calls,
|
||||
)
|
||||
if isinstance(message, SystemMessage):
|
||||
return SystemMessageV1(content=content)
|
||||
if isinstance(message, ToolMessage):
|
||||
return ToolMessageV1(
|
||||
tool_call_id=message.tool_call_id,
|
||||
content=content,
|
||||
artifact=message.artifact,
|
||||
)
|
||||
error_message = f"Unsupported message type: {type(message)}"
|
||||
raise TypeError(error_message)
|
||||
|
||||
|
||||
class PromptValue(Serializable, ABC):
|
||||
@ -46,8 +94,18 @@ class PromptValue(Serializable, ABC):
|
||||
def to_string(self) -> str:
|
||||
"""Return prompt value as string."""
|
||||
|
||||
@overload
|
||||
def to_messages(
|
||||
self, output_version: Literal["v0"] = "v0"
|
||||
) -> list[BaseMessage]: ...
|
||||
|
||||
@overload
|
||||
def to_messages(self, output_version: Literal["v1"]) -> list[MessageV1]: ...
|
||||
|
||||
@abstractmethod
|
||||
def to_messages(self) -> list[BaseMessage]:
|
||||
def to_messages(
|
||||
self, output_version: Literal["v0", "v1"] = "v0"
|
||||
) -> Union[Sequence[BaseMessage], Sequence[MessageV1]]:
|
||||
"""Return prompt as a list of Messages."""
|
||||
|
||||
|
||||
@ -71,8 +129,20 @@ class StringPromptValue(PromptValue):
|
||||
"""Return prompt as string."""
|
||||
return self.text
|
||||
|
||||
def to_messages(self) -> list[BaseMessage]:
|
||||
@overload
|
||||
def to_messages(
|
||||
self, output_version: Literal["v0"] = "v0"
|
||||
) -> list[BaseMessage]: ...
|
||||
|
||||
@overload
|
||||
def to_messages(self, output_version: Literal["v1"]) -> list[MessageV1]: ...
|
||||
|
||||
def to_messages(
|
||||
self, output_version: Literal["v0", "v1"] = "v0"
|
||||
) -> Union[Sequence[BaseMessage], Sequence[MessageV1]]:
|
||||
"""Return prompt as messages."""
|
||||
if output_version == "v1":
|
||||
return [HumanMessageV1(content=self.text)]
|
||||
return [HumanMessage(content=self.text)]
|
||||
|
||||
|
||||
@ -89,8 +159,24 @@ class ChatPromptValue(PromptValue):
|
||||
"""Return prompt as string."""
|
||||
return get_buffer_string(self.messages)
|
||||
|
||||
def to_messages(self) -> list[BaseMessage]:
|
||||
"""Return prompt as a list of messages."""
|
||||
@overload
|
||||
def to_messages(
|
||||
self, output_version: Literal["v0"] = "v0"
|
||||
) -> list[BaseMessage]: ...
|
||||
|
||||
@overload
|
||||
def to_messages(self, output_version: Literal["v1"]) -> list[MessageV1]: ...
|
||||
|
||||
def to_messages(
|
||||
self, output_version: Literal["v0", "v1"] = "v0"
|
||||
) -> Union[Sequence[BaseMessage], Sequence[MessageV1]]:
|
||||
"""Return prompt as a list of messages.
|
||||
|
||||
Args:
|
||||
output_version: The output version, either "v0" (default) or "v1".
|
||||
"""
|
||||
if output_version == "v1":
|
||||
return [_convert_to_v1(m) for m in self.messages]
|
||||
return list(self.messages)
|
||||
|
||||
@classmethod
|
||||
@ -125,8 +211,26 @@ class ImagePromptValue(PromptValue):
|
||||
"""Return prompt (image URL) as string."""
|
||||
return self.image_url["url"]
|
||||
|
||||
def to_messages(self) -> list[BaseMessage]:
|
||||
@overload
|
||||
def to_messages(
|
||||
self, output_version: Literal["v0"] = "v0"
|
||||
) -> list[BaseMessage]: ...
|
||||
|
||||
@overload
|
||||
def to_messages(self, output_version: Literal["v1"]) -> list[MessageV1]: ...
|
||||
|
||||
def to_messages(
|
||||
self, output_version: Literal["v0", "v1"] = "v0"
|
||||
) -> Union[Sequence[BaseMessage], Sequence[MessageV1]]:
|
||||
"""Return prompt (image URL) as messages."""
|
||||
if output_version == "v1":
|
||||
block: types.ImageContentBlock = {
|
||||
"type": "image",
|
||||
"url": self.image_url["url"],
|
||||
}
|
||||
if "detail" in self.image_url:
|
||||
block["detail"] = self.image_url["detail"]
|
||||
return [HumanMessageV1(content=[block])]
|
||||
return [HumanMessage(content=[cast("dict", self.image_url)])]
|
||||
|
||||
|
||||
|
@ -67,6 +67,7 @@ langchain-text-splitters = { path = "../text-splitters" }
|
||||
strict = "True"
|
||||
strict_bytes = "True"
|
||||
enable_error_code = "deprecated"
|
||||
disable_error_code = ["typeddict-unknown-key"]
|
||||
|
||||
# TODO: activate for 'strict' checking
|
||||
disallow_any_generics = "False"
|
||||
|
@ -34,6 +34,7 @@ from langchain_core.messages.content_blocks import KNOWN_BLOCK_TYPES
|
||||
from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
|
||||
from langchain_core.messages.tool import tool_call as create_tool_call
|
||||
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
|
||||
from langchain_core.messages.v1 import AIMessage as AIMessageV1
|
||||
from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
|
||||
from langchain_core.utils._merge import merge_lists
|
||||
|
||||
@ -197,7 +198,7 @@ def test_message_chunks() -> None:
|
||||
assert (meaningful_id + default_id).id == "msg_def456"
|
||||
|
||||
|
||||
def test_message_chunks_v2() -> None:
|
||||
def test_message_chunks_v1() -> None:
|
||||
left = AIMessageChunkV1("foo ", id="abc")
|
||||
right = AIMessageChunkV1("bar")
|
||||
expected = AIMessageChunkV1("foo bar", id="abc")
|
||||
@ -230,7 +231,19 @@ def test_message_chunks_v2() -> None:
|
||||
)
|
||||
],
|
||||
)
|
||||
assert one + two + three == expected
|
||||
result = one + two + three
|
||||
assert result == expected
|
||||
|
||||
assert result.to_message() == AIMessageV1(
|
||||
content=[
|
||||
{
|
||||
"name": "tool1",
|
||||
"args": {"arg1": "value}"},
|
||||
"id": "1",
|
||||
"type": "tool_call",
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
assert (
|
||||
AIMessageChunkV1(
|
||||
@ -1326,6 +1339,7 @@ def test_known_block_types() -> None:
|
||||
"text",
|
||||
"text-plain",
|
||||
"tool_call",
|
||||
"invalid_tool_call",
|
||||
"reasoning",
|
||||
"non_standard",
|
||||
"image",
|
||||
|
@ -1,10 +1,11 @@
|
||||
from langchain_openai.chat_models import AzureChatOpenAI, ChatOpenAI
|
||||
from langchain_openai.chat_models import AzureChatOpenAI, ChatOpenAI, ChatOpenAIV1
|
||||
from langchain_openai.embeddings import AzureOpenAIEmbeddings, OpenAIEmbeddings
|
||||
from langchain_openai.llms import AzureOpenAI, OpenAI
|
||||
|
||||
__all__ = [
|
||||
"OpenAI",
|
||||
"ChatOpenAI",
|
||||
"ChatOpenAIV1",
|
||||
"OpenAIEmbeddings",
|
||||
"AzureOpenAI",
|
||||
"AzureChatOpenAI",
|
||||
|
@ -1,4 +1,5 @@
|
||||
from langchain_openai.chat_models.azure import AzureChatOpenAI
|
||||
from langchain_openai.chat_models.base import ChatOpenAI
|
||||
from langchain_openai.chat_models.base_v1 import ChatOpenAI as ChatOpenAIV1
|
||||
|
||||
__all__ = ["ChatOpenAI", "AzureChatOpenAI"]
|
||||
__all__ = ["ChatOpenAI", "AzureChatOpenAI", "ChatOpenAIV1"]
|
||||
|
@ -66,11 +66,14 @@ For backwards compatibility, this module provides functions to convert between t
|
||||
formats. The functions are used internally by ChatOpenAI.
|
||||
""" # noqa: E501
|
||||
|
||||
import copy
|
||||
import json
|
||||
from collections.abc import Iterable, Iterator
|
||||
from typing import Any, Literal, Union, cast
|
||||
from typing import Any, Literal, Optional, Union, cast
|
||||
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk, is_data_content_block
|
||||
from langchain_core.messages import content_blocks as types
|
||||
from langchain_core.messages.v1 import AIMessage as AIMessageV1
|
||||
|
||||
_FUNCTION_CALL_IDS_MAP_KEY = "__openai_function_call_ids__"
|
||||
|
||||
@ -289,25 +292,21 @@ def _convert_to_v1_from_chat_completions_chunk(chunk: AIMessageChunk) -> AIMessa
|
||||
return cast(AIMessageChunk, result)
|
||||
|
||||
|
||||
def _convert_from_v1_to_chat_completions(message: AIMessage) -> AIMessage:
|
||||
def _convert_from_v1_to_chat_completions(message: AIMessageV1) -> AIMessageV1:
|
||||
"""Convert a v1 message to the Chat Completions format."""
|
||||
if isinstance(message.content, list):
|
||||
new_content: list = []
|
||||
for block in message.content:
|
||||
if isinstance(block, dict):
|
||||
block_type = block.get("type")
|
||||
if block_type == "text":
|
||||
# Strip annotations
|
||||
new_content.append({"type": "text", "text": block["text"]})
|
||||
elif block_type in ("reasoning", "tool_call"):
|
||||
pass
|
||||
else:
|
||||
new_content.append(block)
|
||||
else:
|
||||
new_content.append(block)
|
||||
return message.model_copy(update={"content": new_content})
|
||||
new_content: list[types.ContentBlock] = []
|
||||
for block in message.content:
|
||||
if block["type"] == "text":
|
||||
# Strip annotations
|
||||
new_content.append({"type": "text", "text": block["text"]})
|
||||
elif block["type"] in ("reasoning", "tool_call"):
|
||||
pass
|
||||
else:
|
||||
new_content.append(block)
|
||||
new_message = copy.copy(message)
|
||||
new_message.content = new_content
|
||||
|
||||
return message
|
||||
return new_message
|
||||
|
||||
|
||||
# v1 / Responses
|
||||
@ -319,17 +318,18 @@ def _convert_annotation_to_v1(annotation: dict[str, Any]) -> dict[str, Any]:
|
||||
for field in ("end_index", "start_index", "title"):
|
||||
if field in annotation:
|
||||
url_citation[field] = annotation[field]
|
||||
url_citation["type"] = "url_citation"
|
||||
url_citation["type"] = "citation"
|
||||
url_citation["url"] = annotation["url"]
|
||||
return url_citation
|
||||
|
||||
elif annotation_type == "file_citation":
|
||||
document_citation = {"type": "document_citation"}
|
||||
document_citation = {"type": "citation"}
|
||||
if "filename" in annotation:
|
||||
document_citation["title"] = annotation["filename"]
|
||||
for field in ("file_id", "index"): # OpenAI-specific
|
||||
if field in annotation:
|
||||
document_citation[field] = annotation[field]
|
||||
if "file_id" in annotation:
|
||||
document_citation["file_id"] = annotation["file_id"]
|
||||
if "index" in annotation:
|
||||
document_citation["file_index"] = annotation["index"]
|
||||
return document_citation
|
||||
|
||||
# TODO: standardise container_file_citation?
|
||||
@ -367,13 +367,15 @@ def _explode_reasoning(block: dict[str, Any]) -> Iterable[dict[str, Any]]:
|
||||
yield new_block
|
||||
|
||||
|
||||
def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage:
|
||||
def _convert_to_v1_from_responses(
|
||||
content: list[dict[str, Any]],
|
||||
tool_calls: Optional[list[types.ToolCall]] = None,
|
||||
invalid_tool_calls: Optional[list[types.InvalidToolCall]] = None,
|
||||
) -> list[types.ContentBlock]:
|
||||
"""Mutate a Responses message to v1 format."""
|
||||
if not isinstance(message.content, list):
|
||||
return message
|
||||
|
||||
def _iter_blocks() -> Iterable[dict[str, Any]]:
|
||||
for block in message.content:
|
||||
for block in content:
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
block_type = block.get("type")
|
||||
@ -409,13 +411,24 @@ def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage:
|
||||
yield new_block
|
||||
|
||||
elif block_type == "function_call":
|
||||
new_block = {"type": "tool_call", "id": block.get("call_id", "")}
|
||||
if "id" in block:
|
||||
new_block["item_id"] = block["id"]
|
||||
for extra_key in ("arguments", "name", "index"):
|
||||
if extra_key in block:
|
||||
new_block[extra_key] = block[extra_key]
|
||||
yield new_block
|
||||
new_block = None
|
||||
call_id = block.get("call_id", "")
|
||||
if call_id:
|
||||
for tool_call in tool_calls or []:
|
||||
if tool_call.get("id") == call_id:
|
||||
new_block = tool_call.copy()
|
||||
break
|
||||
else:
|
||||
for invalid_tool_call in invalid_tool_calls or []:
|
||||
if invalid_tool_call.get("id") == call_id:
|
||||
new_block = invalid_tool_call.copy()
|
||||
break
|
||||
if new_block:
|
||||
if "id" in block:
|
||||
new_block["item_id"] = block["id"]
|
||||
if "index" in block:
|
||||
new_block["index"] = block["index"]
|
||||
yield new_block
|
||||
|
||||
elif block_type == "web_search_call":
|
||||
web_search_call = {"type": "web_search_call", "id": block["id"]}
|
||||
@ -485,28 +498,26 @@ def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage:
|
||||
new_block["index"] = new_block["value"].pop("index")
|
||||
yield new_block
|
||||
|
||||
# Replace the list with the fully converted one
|
||||
message.content = list(_iter_blocks())
|
||||
|
||||
return message
|
||||
return list(_iter_blocks())
|
||||
|
||||
|
||||
def _convert_annotation_from_v1(annotation: dict[str, Any]) -> dict[str, Any]:
|
||||
annotation_type = annotation.get("type")
|
||||
def _convert_annotation_from_v1(annotation: types.Annotation) -> dict[str, Any]:
|
||||
if annotation["type"] == "citation":
|
||||
if "url" in annotation:
|
||||
return {**annotation, "type": "url_citation"}
|
||||
|
||||
if annotation_type == "document_citation":
|
||||
new_ann: dict[str, Any] = {"type": "file_citation"}
|
||||
|
||||
if "title" in annotation:
|
||||
new_ann["filename"] = annotation["title"]
|
||||
|
||||
for fld in ("file_id", "index"):
|
||||
if fld in annotation:
|
||||
new_ann[fld] = annotation[fld]
|
||||
if "file_id" in annotation:
|
||||
new_ann["file_id"] = annotation["file_id"]
|
||||
if "file_index" in annotation:
|
||||
new_ann["index"] = annotation["file_index"]
|
||||
|
||||
return new_ann
|
||||
|
||||
elif annotation_type == "non_standard_annotation":
|
||||
elif annotation["type"] == "non_standard_annotation":
|
||||
return annotation["value"]
|
||||
|
||||
else:
|
||||
@ -528,7 +539,10 @@ def _implode_reasoning_blocks(blocks: list[dict[str, Any]]) -> Iterable[dict[str
|
||||
elif "reasoning" not in block and "summary" not in block:
|
||||
# {"type": "reasoning", "id": "rs_..."}
|
||||
oai_format = {**block, "summary": []}
|
||||
# Update key order
|
||||
oai_format["type"] = oai_format.pop("type", "reasoning")
|
||||
if "encrypted_content" in oai_format:
|
||||
oai_format["encrypted_content"] = oai_format.pop("encrypted_content")
|
||||
yield oai_format
|
||||
i += 1
|
||||
continue
|
||||
@ -594,13 +608,11 @@ def _consolidate_calls(
|
||||
# If this really is the matching “result” – collapse
|
||||
if nxt.get("type") == result_name and nxt.get("id") == current.get("id"):
|
||||
if call_name == "web_search_call":
|
||||
collapsed = {
|
||||
"id": current["id"],
|
||||
"status": current["status"],
|
||||
"type": "web_search_call",
|
||||
}
|
||||
collapsed = {"id": current["id"]}
|
||||
if "action" in current:
|
||||
collapsed["action"] = current["action"]
|
||||
collapsed["status"] = current["status"]
|
||||
collapsed["type"] = "web_search_call"
|
||||
|
||||
if call_name == "code_interpreter_call":
|
||||
collapsed = {"id": current["id"]}
|
||||
@ -621,51 +633,50 @@ def _consolidate_calls(
|
||||
yield nxt
|
||||
|
||||
|
||||
def _convert_from_v1_to_responses(message: AIMessage) -> AIMessage:
|
||||
if not isinstance(message.content, list):
|
||||
return message
|
||||
|
||||
def _convert_from_v1_to_responses(
|
||||
content: list[types.ContentBlock], tool_calls: list[types.ToolCall]
|
||||
) -> list[dict[str, Any]]:
|
||||
new_content: list = []
|
||||
for block in message.content:
|
||||
if isinstance(block, dict):
|
||||
block_type = block.get("type")
|
||||
if block_type == "text" and "annotations" in block:
|
||||
# Need a copy because we’re changing the annotations list
|
||||
new_block = dict(block)
|
||||
new_block["annotations"] = [
|
||||
_convert_annotation_from_v1(a) for a in block["annotations"]
|
||||
for block in content:
|
||||
if block["type"] == "text" and "annotations" in block:
|
||||
# Need a copy because we’re changing the annotations list
|
||||
new_block = dict(block)
|
||||
new_block["annotations"] = [
|
||||
_convert_annotation_from_v1(a) for a in block["annotations"]
|
||||
]
|
||||
new_content.append(new_block)
|
||||
elif block["type"] == "tool_call":
|
||||
new_block = {"type": "function_call", "call_id": block["id"]}
|
||||
if "item_id" in block:
|
||||
new_block["id"] = block["item_id"] # type: ignore[typeddict-item]
|
||||
if "name" in block and "arguments" in block:
|
||||
new_block["name"] = block["name"]
|
||||
new_block["arguments"] = block["arguments"] # type: ignore[typeddict-item]
|
||||
else:
|
||||
matching_tool_calls = [
|
||||
call for call in tool_calls if call["id"] == block["id"]
|
||||
]
|
||||
new_content.append(new_block)
|
||||
elif block_type == "tool_call":
|
||||
new_block = {"type": "function_call", "call_id": block["id"]}
|
||||
if "item_id" in block:
|
||||
new_block["id"] = block["item_id"]
|
||||
if "name" in block and "arguments" in block:
|
||||
new_block["name"] = block["name"]
|
||||
new_block["arguments"] = block["arguments"]
|
||||
else:
|
||||
tool_call = next(
|
||||
call for call in message.tool_calls if call["id"] == block["id"]
|
||||
)
|
||||
if matching_tool_calls:
|
||||
tool_call = matching_tool_calls[0]
|
||||
if "name" not in block:
|
||||
new_block["name"] = tool_call["name"]
|
||||
if "arguments" not in block:
|
||||
new_block["arguments"] = json.dumps(tool_call["args"])
|
||||
new_content.append(new_block)
|
||||
elif (
|
||||
is_data_content_block(block)
|
||||
and block["type"] == "image"
|
||||
and "base64" in block
|
||||
):
|
||||
new_block = {"type": "image_generation_call", "result": block["base64"]}
|
||||
for extra_key in ("id", "status"):
|
||||
if extra_key in block:
|
||||
new_block[extra_key] = block[extra_key]
|
||||
new_content.append(new_block)
|
||||
elif block_type == "non_standard" and "value" in block:
|
||||
new_content.append(block["value"])
|
||||
else:
|
||||
new_content.append(block)
|
||||
new_content.append(new_block)
|
||||
elif (
|
||||
is_data_content_block(cast(dict, block))
|
||||
and block["type"] == "image"
|
||||
and "base64" in block
|
||||
and isinstance(block.get("id"), str)
|
||||
and block["id"].startswith("ig_")
|
||||
):
|
||||
new_block = {"type": "image_generation_call", "result": block["base64"]}
|
||||
for extra_key in ("id", "status"):
|
||||
if extra_key in block:
|
||||
new_block[extra_key] = block[extra_key] # type: ignore[typeddict-item]
|
||||
new_content.append(new_block)
|
||||
elif block["type"] == "non_standard" and "value" in block:
|
||||
new_content.append(block["value"])
|
||||
else:
|
||||
new_content.append(block)
|
||||
|
||||
@ -679,4 +690,4 @@ def _convert_from_v1_to_responses(message: AIMessage) -> AIMessage:
|
||||
)
|
||||
)
|
||||
|
||||
return message.model_copy(update={"content": new_content})
|
||||
return new_content
|
||||
|
3813
libs/partners/openai/langchain_openai/chat_models/base_v1.py
Normal file
3813
libs/partners/openai/langchain_openai/chat_models/base_v1.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -56,6 +56,8 @@ langchain-tests = { path = "../../standard-tests", editable = true }
|
||||
|
||||
[tool.mypy]
|
||||
disallow_untyped_defs = "True"
|
||||
disable_error_code = ["typeddict-unknown-key"]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "transformers"
|
||||
ignore_missing_imports = true
|
||||
|
Binary file not shown.
@ -14,16 +14,24 @@ from langchain_core.messages import (
|
||||
HumanMessage,
|
||||
MessageLikeRepresentation,
|
||||
)
|
||||
from langchain_core.messages.v1 import AIMessage as AIMessageV1
|
||||
from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
|
||||
from langchain_core.messages.v1 import HumanMessage as HumanMessageV1
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_openai import ChatOpenAI, ChatOpenAIV1
|
||||
|
||||
MODEL_NAME = "gpt-4o-mini"
|
||||
|
||||
|
||||
def _check_response(response: Optional[BaseMessage]) -> None:
|
||||
assert isinstance(response, AIMessage)
|
||||
def _check_response(response: Optional[BaseMessage], output_version) -> None:
|
||||
if output_version == "v1":
|
||||
assert isinstance(response, AIMessageV1) or isinstance(
|
||||
response, AIMessageChunkV1
|
||||
)
|
||||
else:
|
||||
assert isinstance(response, AIMessage)
|
||||
assert isinstance(response.content, list)
|
||||
for block in response.content:
|
||||
assert isinstance(block, dict)
|
||||
@ -41,7 +49,10 @@ def _check_response(response: Optional[BaseMessage]) -> None:
|
||||
for key in ["end_index", "start_index", "title", "type", "url"]
|
||||
)
|
||||
|
||||
text_content = response.text()
|
||||
if output_version == "v1":
|
||||
text_content = response.text
|
||||
else:
|
||||
text_content = response.text()
|
||||
assert isinstance(text_content, str)
|
||||
assert text_content
|
||||
assert response.usage_metadata
|
||||
@ -56,22 +67,34 @@ def _check_response(response: Optional[BaseMessage]) -> None:
|
||||
@pytest.mark.vcr
|
||||
@pytest.mark.parametrize("output_version", ["responses/v1", "v1"])
|
||||
def test_web_search(output_version: Literal["responses/v1", "v1"]) -> None:
|
||||
llm = ChatOpenAI(model=MODEL_NAME, output_version=output_version)
|
||||
if output_version == "v1":
|
||||
llm = ChatOpenAIV1(model=MODEL_NAME)
|
||||
else:
|
||||
llm = ChatOpenAI(model=MODEL_NAME, output_version=output_version)
|
||||
first_response = llm.invoke(
|
||||
"What was a positive news story from today?",
|
||||
tools=[{"type": "web_search_preview"}],
|
||||
)
|
||||
_check_response(first_response)
|
||||
_check_response(first_response, output_version)
|
||||
|
||||
# Test streaming
|
||||
full: Optional[BaseMessageChunk] = None
|
||||
for chunk in llm.stream(
|
||||
"What was a positive news story from today?",
|
||||
tools=[{"type": "web_search_preview"}],
|
||||
):
|
||||
assert isinstance(chunk, AIMessageChunk)
|
||||
full = chunk if full is None else full + chunk
|
||||
_check_response(full)
|
||||
if isinstance(llm, ChatOpenAIV1):
|
||||
full: Optional[AIMessageChunkV1] = None
|
||||
for chunk in llm.stream(
|
||||
"What was a positive news story from today?",
|
||||
tools=[{"type": "web_search_preview"}],
|
||||
):
|
||||
assert isinstance(chunk, AIMessageChunkV1)
|
||||
full = chunk if full is None else full + chunk
|
||||
else:
|
||||
full: Optional[BaseMessageChunk] = None
|
||||
for chunk in llm.stream(
|
||||
"What was a positive news story from today?",
|
||||
tools=[{"type": "web_search_preview"}],
|
||||
):
|
||||
assert isinstance(chunk, AIMessageChunk)
|
||||
full = chunk if full is None else full + chunk
|
||||
_check_response(full, output_version)
|
||||
|
||||
# Use OpenAI's stateful API
|
||||
response = llm.invoke(
|
||||
@ -79,38 +102,26 @@ def test_web_search(output_version: Literal["responses/v1", "v1"]) -> None:
|
||||
tools=[{"type": "web_search_preview"}],
|
||||
previous_response_id=first_response.response_metadata["id"],
|
||||
)
|
||||
_check_response(response)
|
||||
_check_response(response, output_version)
|
||||
|
||||
# Manually pass in chat history
|
||||
response = llm.invoke(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What was a positive news story from today?",
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "user", "content": "What was a positive news story from today?"},
|
||||
first_response,
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "what about a negative one"}],
|
||||
},
|
||||
{"role": "user", "content": "what about a negative one"},
|
||||
],
|
||||
tools=[{"type": "web_search_preview"}],
|
||||
)
|
||||
_check_response(response)
|
||||
_check_response(response, output_version)
|
||||
|
||||
# Bind tool
|
||||
response = llm.bind_tools([{"type": "web_search_preview"}]).invoke(
|
||||
"What was a positive news story from today?"
|
||||
)
|
||||
_check_response(response)
|
||||
_check_response(response, output_version)
|
||||
|
||||
for msg in [first_response, full, response]:
|
||||
assert isinstance(msg, AIMessage)
|
||||
block_types = [block["type"] for block in msg.content] # type: ignore[index]
|
||||
if output_version == "responses/v1":
|
||||
assert block_types == ["web_search_call", "text"]
|
||||
@ -125,7 +136,7 @@ async def test_web_search_async() -> None:
|
||||
"What was a positive news story from today?",
|
||||
tools=[{"type": "web_search_preview"}],
|
||||
)
|
||||
_check_response(response)
|
||||
_check_response(response, "v0")
|
||||
assert response.response_metadata["status"]
|
||||
|
||||
# Test streaming
|
||||
@ -137,7 +148,7 @@ async def test_web_search_async() -> None:
|
||||
assert isinstance(chunk, AIMessageChunk)
|
||||
full = chunk if full is None else full + chunk
|
||||
assert isinstance(full, AIMessageChunk)
|
||||
_check_response(full)
|
||||
_check_response(full, "v0")
|
||||
|
||||
for msg in [response, full]:
|
||||
assert msg.additional_kwargs["tool_outputs"]
|
||||
@ -148,8 +159,8 @@ async def test_web_search_async() -> None:
|
||||
|
||||
@pytest.mark.default_cassette("test_function_calling.yaml.gz")
|
||||
@pytest.mark.vcr
|
||||
@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
|
||||
def test_function_calling(output_version: Literal["v0", "responses/v1", "v1"]) -> None:
|
||||
@pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
|
||||
def test_function_calling(output_version: Literal["v0", "responses/v1"]) -> None:
|
||||
def multiply(x: int, y: int) -> int:
|
||||
"""return x * y"""
|
||||
return x * y
|
||||
@ -170,7 +181,33 @@ def test_function_calling(output_version: Literal["v0", "responses/v1", "v1"]) -
|
||||
assert set(full.tool_calls[0]["args"]) == {"x", "y"}
|
||||
|
||||
response = bound_llm.invoke("What was a positive news story from today?")
|
||||
_check_response(response)
|
||||
_check_response(response, output_version)
|
||||
|
||||
|
||||
@pytest.mark.default_cassette("test_function_calling.yaml.gz")
|
||||
@pytest.mark.vcr
|
||||
def test_function_calling_v1() -> None:
|
||||
def multiply(x: int, y: int) -> int:
|
||||
"""return x * y"""
|
||||
return x * y
|
||||
|
||||
llm = ChatOpenAIV1(model=MODEL_NAME)
|
||||
bound_llm = llm.bind_tools([multiply, {"type": "web_search_preview"}])
|
||||
ai_msg = bound_llm.invoke("whats 5 * 4")
|
||||
assert len(ai_msg.tool_calls) == 1
|
||||
assert ai_msg.tool_calls[0]["name"] == "multiply"
|
||||
assert set(ai_msg.tool_calls[0]["args"]) == {"x", "y"}
|
||||
|
||||
full: Any = None
|
||||
for chunk in bound_llm.stream("whats 5 * 4"):
|
||||
assert isinstance(chunk, AIMessageChunkV1)
|
||||
full = chunk if full is None else full + chunk
|
||||
assert len(full.tool_calls) == 1
|
||||
assert full.tool_calls[0]["name"] == "multiply"
|
||||
assert set(full.tool_calls[0]["args"]) == {"x", "y"}
|
||||
|
||||
response = bound_llm.invoke("What was a positive news story from today?")
|
||||
_check_response(response, "v1")
|
||||
|
||||
|
||||
class Foo(BaseModel):
@ -183,10 +220,8 @@ class FooDict(TypedDict):

@pytest.mark.default_cassette("test_parsed_pydantic_schema.yaml.gz")
@pytest.mark.vcr
@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
def test_parsed_pydantic_schema(
    output_version: Literal["v0", "responses/v1", "v1"],
) -> None:
@pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
def test_parsed_pydantic_schema(output_version: Literal["v0", "responses/v1"]) -> None:
    llm = ChatOpenAI(
        model=MODEL_NAME, use_responses_api=True, output_version=output_version
    )
@ -206,6 +241,28 @@ def test_parsed_pydantic_schema(
    assert parsed.response

@pytest.mark.default_cassette("test_parsed_pydantic_schema.yaml.gz")
|
||||
@pytest.mark.vcr
|
||||
def test_parsed_pydantic_schema_v1() -> None:
|
||||
llm = ChatOpenAIV1(model=MODEL_NAME, use_responses_api=True)
|
||||
response = llm.invoke("how are ya", response_format=Foo)
|
||||
parsed = Foo(**json.loads(response.text))
|
||||
assert parsed == response.parsed
|
||||
assert parsed.response
|
||||
|
||||
# Test stream
|
||||
full: Optional[AIMessageChunkV1] = None
|
||||
chunks = []
|
||||
for chunk in llm.stream("how are ya", response_format=Foo):
|
||||
assert isinstance(chunk, AIMessageChunkV1)
|
||||
full = chunk if full is None else full + chunk
|
||||
chunks.append(chunk)
|
||||
assert isinstance(full, AIMessageChunkV1)
|
||||
parsed = Foo(**json.loads(full.text))
|
||||
assert parsed == full.parsed
|
||||
assert parsed.response
|
||||
|
||||
|
||||
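The v1 tests above pass `response_format=Foo` at call time and read `.parsed` off the message; `with_structured_output` is the long-standing way to get a parsed Pydantic object back from the stable client. A sketch of the latter (API key required; the model name is illustrative):

```python
from langchain_openai import ChatOpenAI
from pydantic import BaseModel


class Foo(BaseModel):
    response: str


llm = ChatOpenAI(model="gpt-4o-mini")  # illustrative model name
structured_llm = llm.with_structured_output(Foo)

result = structured_llm.invoke("how are ya")
assert isinstance(result, Foo)
print(result.response)
```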
async def test_parsed_pydantic_schema_async() -> None:
    llm = ChatOpenAI(model=MODEL_NAME, use_responses_api=True)
    response = await llm.ainvoke("how are ya", response_format=Foo)
@ -311,8 +368,8 @@ def test_function_calling_and_structured_output() -> None:

@pytest.mark.default_cassette("test_reasoning.yaml.gz")
@pytest.mark.vcr
@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
def test_reasoning(output_version: Literal["v0", "responses/v1", "v1"]) -> None:
@pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
def test_reasoning(output_version: Literal["v0", "responses/v1"]) -> None:
    llm = ChatOpenAI(
        model="o4-mini", use_responses_api=True, output_version=output_version
    )
@ -337,6 +394,26 @@ def test_reasoning(output_version: Literal["v0", "responses/v1", "v1"]) -> None:
    assert block_types == ["reasoning", "text"]

@pytest.mark.default_cassette("test_reasoning.yaml.gz")
|
||||
@pytest.mark.vcr
|
||||
def test_reasoning_v1() -> None:
|
||||
llm = ChatOpenAIV1(model="o4-mini", use_responses_api=True)
|
||||
response = llm.invoke("Hello", reasoning={"effort": "low"})
|
||||
assert isinstance(response, AIMessageV1)
|
||||
|
||||
# Test init params + streaming
|
||||
llm = ChatOpenAIV1(model="o4-mini", reasoning={"effort": "low"})
|
||||
full: Optional[AIMessageChunkV1] = None
|
||||
for chunk in llm.stream("Hello"):
|
||||
assert isinstance(chunk, AIMessageChunkV1)
|
||||
full = chunk if full is None else full + chunk
|
||||
assert isinstance(full, AIMessageChunkV1)
|
||||
|
||||
for msg in [response, full]:
|
||||
block_types = [block["type"] for block in msg.content]
|
||||
assert block_types == ["reasoning", "text"]
|
||||
|
||||
|
||||
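In the v1 format the message content is a flat list of typed blocks, so "did the model reason before answering" reduces to inspecting the `type` field of each block. A small helper over the block shape asserted above (the sample data is illustrative):

```python
from typing import Any


def block_types(content: list[dict[str, Any]]) -> list[str]:
    """Return the `type` field of each content block, in order."""
    return [block["type"] for block in content]


# Shape mirrors the assertions in test_reasoning_v1 above.
sample_content = [
    {"type": "reasoning", "reasoning": "Thinking about the greeting..."},
    {"type": "text", "text": "Hello! How can I help?"},
]

assert block_types(sample_content) == ["reasoning", "text"]
```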
def test_stateful_api() -> None:
    llm = ChatOpenAI(model=MODEL_NAME, use_responses_api=True)
    response = llm.invoke("how are you, my name is Bobo")
@ -380,14 +457,14 @@ def test_file_search() -> None:

    input_message = {"role": "user", "content": "What is deep research by OpenAI?"}
    response = llm.invoke([input_message], tools=[tool])
    _check_response(response)
    _check_response(response, "v0")

    full: Optional[BaseMessageChunk] = None
    for chunk in llm.stream([input_message], tools=[tool]):
        assert isinstance(chunk, AIMessageChunk)
        full = chunk if full is None else full + chunk
    assert isinstance(full, AIMessageChunk)
    _check_response(full)
    _check_response(full, "v0")

    next_message = {"role": "user", "content": "Thank you."}
    _ = llm.invoke([input_message, full, next_message])
@ -395,9 +472,9 @@ def test_file_search() -> None:

@pytest.mark.default_cassette("test_stream_reasoning_summary.yaml.gz")
|
||||
@pytest.mark.vcr
|
||||
@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
|
||||
@pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
|
||||
def test_stream_reasoning_summary(
|
||||
output_version: Literal["v0", "responses/v1", "v1"],
|
||||
output_version: Literal["v0", "responses/v1"],
|
||||
) -> None:
|
||||
llm = ChatOpenAI(
|
||||
model="o4-mini",
|
||||
@ -424,7 +501,8 @@ def test_stream_reasoning_summary(
|
||||
assert isinstance(block["type"], str)
|
||||
assert isinstance(block["text"], str)
|
||||
assert block["text"]
|
||||
elif output_version == "responses/v1":
|
||||
else:
|
||||
# output_version == "responses/v1"
|
||||
reasoning = next(
|
||||
block
|
||||
for block in response_1.content
|
||||
@ -438,18 +516,6 @@ def test_stream_reasoning_summary(
|
||||
assert isinstance(block["type"], str)
|
||||
assert isinstance(block["text"], str)
|
||||
assert block["text"]
|
||||
else:
|
||||
# v1
|
||||
total_reasoning_blocks = 0
|
||||
for block in response_1.content:
|
||||
if block["type"] == "reasoning":
|
||||
total_reasoning_blocks += 1
|
||||
assert isinstance(block["id"], str) and block["id"].startswith("rs_")
|
||||
assert isinstance(block["reasoning"], str)
|
||||
assert isinstance(block["index"], int)
|
||||
assert (
|
||||
total_reasoning_blocks > 1
|
||||
) # This query typically generates multiple reasoning blocks
|
||||
|
||||
# Check we can pass back summaries
|
||||
message_2 = {"role": "user", "content": "Thank you."}
|
||||
@ -457,10 +523,45 @@ def test_stream_reasoning_summary(
|
||||
assert isinstance(response_2, AIMessage)
|
||||
|
||||
|
||||
@pytest.mark.default_cassette("test_stream_reasoning_summary.yaml.gz")
|
||||
@pytest.mark.vcr
|
||||
def test_stream_reasoning_summary_v1() -> None:
|
||||
llm = ChatOpenAIV1(
|
||||
model="o4-mini",
|
||||
# Routes to Responses API if `reasoning` is set.
|
||||
reasoning={"effort": "medium", "summary": "auto"},
|
||||
)
|
||||
message_1 = {
|
||||
"role": "user",
|
||||
"content": "What was the third tallest buliding in the year 2000?",
|
||||
}
|
||||
response_1: Optional[AIMessageChunkV1] = None
|
||||
for chunk in llm.stream([message_1]):
|
||||
assert isinstance(chunk, AIMessageChunkV1)
|
||||
response_1 = chunk if response_1 is None else response_1 + chunk
|
||||
assert isinstance(response_1, AIMessageChunkV1)
|
||||
|
||||
total_reasoning_blocks = 0
|
||||
for block in response_1.content:
|
||||
if block["type"] == "reasoning":
|
||||
total_reasoning_blocks += 1
|
||||
assert isinstance(block["id"], str) and block["id"].startswith("rs_")
|
||||
assert isinstance(block["reasoning"], str)
|
||||
assert isinstance(block["index"], int)
|
||||
assert (
|
||||
total_reasoning_blocks > 1
|
||||
) # This query typically generates multiple reasoning blocks
|
||||
|
||||
# Check we can pass back summaries
|
||||
message_2 = {"role": "user", "content": "Thank you."}
|
||||
response_2 = llm.invoke([message_1, response_1, message_2])
|
||||
assert isinstance(response_2, AIMessageV1)
|
||||
|
||||
|
||||
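The multi-turn pattern exercised above, accumulate a streamed reply and then send it back as part of the history, also works with the stable client. A hedged sketch (API key required; model, prompts, and reasoning settings are illustrative and omitted here):

```python
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini", use_responses_api=True)  # illustrative model

message_1 = {"role": "user", "content": "What was the third tallest building in 2000?"}

# Accumulate the streamed reply into a single message chunk.
full = None
for chunk in llm.stream([message_1]):
    full = chunk if full is None else full + chunk

# The accumulated chunk can be placed directly into the next request's history.
message_2 = {"role": "user", "content": "Thank you."}
response_2 = llm.invoke([message_1, full, message_2])
print(response_2.content)
```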
@pytest.mark.default_cassette("test_code_interpreter.yaml.gz")
|
||||
@pytest.mark.vcr
|
||||
@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
|
||||
def test_code_interpreter(output_version: Literal["v0", "responses/v1", "v1"]) -> None:
|
||||
@pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
|
||||
def test_code_interpreter(output_version: Literal["v0", "responses/v1"]) -> None:
|
||||
llm = ChatOpenAI(
|
||||
model="o4-mini", use_responses_api=True, output_version=output_version
|
||||
)
|
||||
@ -473,33 +574,20 @@ def test_code_interpreter(output_version: Literal["v0", "responses/v1", "v1"]) -
|
||||
}
|
||||
response = llm_with_tools.invoke([input_message])
|
||||
assert isinstance(response, AIMessage)
|
||||
_check_response(response)
|
||||
_check_response(response, output_version)
|
||||
if output_version == "v0":
|
||||
tool_outputs = [
|
||||
item
|
||||
for item in response.additional_kwargs["tool_outputs"]
|
||||
if item["type"] == "code_interpreter_call"
|
||||
]
|
||||
elif output_version == "responses/v1":
|
||||
tool_outputs = [
|
||||
item
|
||||
for item in response.content
|
||||
if isinstance(item, dict) and item["type"] == "code_interpreter_call"
|
||||
]
|
||||
else:
|
||||
# v1
|
||||
# responses/v1
|
||||
tool_outputs = [
|
||||
item
|
||||
for item in response.content
|
||||
if isinstance(item, dict) and item["type"] == "code_interpreter_call"
|
||||
]
|
||||
code_interpreter_result = next(
|
||||
item
|
||||
for item in response.content
|
||||
if isinstance(item, dict) and item["type"] == "code_interpreter_result"
|
||||
)
|
||||
assert tool_outputs
|
||||
assert code_interpreter_result
|
||||
assert len(tool_outputs) == 1
|
||||
|
||||
# Test streaming
|
||||
@ -520,25 +608,65 @@ def test_code_interpreter(output_version: Literal["v0", "responses/v1", "v1"]) -
            for item in response.additional_kwargs["tool_outputs"]
            if item["type"] == "code_interpreter_call"
        ]
    elif output_version == "responses/v1":
    else:
        # responses/v1
        tool_outputs = [
            item
            for item in response.content
            if isinstance(item, dict) and item["type"] == "code_interpreter_call"
        ]
    else:
        code_interpreter_call = next(
            item
            for item in response.content
            if isinstance(item, dict) and item["type"] == "code_interpreter_call"
        )
        code_interpreter_result = next(
            item
            for item in response.content
            if isinstance(item, dict) and item["type"] == "code_interpreter_result"
        )
        assert code_interpreter_call
        assert code_interpreter_result
    assert tool_outputs

    # Test we can pass back in
    next_message = {"role": "user", "content": "Please add more comments to the code."}
    _ = llm_with_tools.invoke([input_message, full, next_message])

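A hedged sketch of the built-in `code_interpreter` tool usage that the v1 test below exercises, written against the stable client in `responses/v1` output mode (API key required; the model name is illustrative):

```python
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(
    model="o4-mini",
    use_responses_api=True,
    output_version="responses/v1",
)
llm_with_tools = llm.bind_tools(
    [{"type": "code_interpreter", "container": {"type": "auto"}}]
)

response = llm_with_tools.invoke(
    [{"role": "user", "content": "Write and run code to answer: what is 3^3?"}]
)

# In responses/v1 mode the tool invocation shows up as a typed content block.
calls = [
    item
    for item in response.content
    if isinstance(item, dict) and item["type"] == "code_interpreter_call"
]
print(len(calls), "code interpreter call(s)")
```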
@pytest.mark.default_cassette("test_code_interpreter.yaml.gz")
|
||||
@pytest.mark.vcr
|
||||
def test_code_interpreter_v1() -> None:
|
||||
llm = ChatOpenAIV1(model="o4-mini", use_responses_api=True)
|
||||
llm_with_tools = llm.bind_tools(
|
||||
[{"type": "code_interpreter", "container": {"type": "auto"}}]
|
||||
)
|
||||
input_message = {
|
||||
"role": "user",
|
||||
"content": "Write and run code to answer the question: what is 3^3?",
|
||||
}
|
||||
response = llm_with_tools.invoke([input_message])
|
||||
assert isinstance(response, AIMessageV1)
|
||||
_check_response(response, "v1")
|
||||
|
||||
tool_outputs = [
|
||||
item for item in response.content if item["type"] == "code_interpreter_call"
|
||||
]
|
||||
code_interpreter_result = next(
|
||||
item for item in response.content if item["type"] == "code_interpreter_result"
|
||||
)
|
||||
assert tool_outputs
|
||||
assert code_interpreter_result
|
||||
assert len(tool_outputs) == 1
|
||||
|
||||
# Test streaming
|
||||
# Use same container
|
||||
container_id = tool_outputs[0]["container_id"]
|
||||
llm_with_tools = llm.bind_tools(
|
||||
[{"type": "code_interpreter", "container": container_id}]
|
||||
)
|
||||
|
||||
full: Optional[AIMessageChunkV1] = None
|
||||
for chunk in llm_with_tools.stream([input_message]):
|
||||
assert isinstance(chunk, AIMessageChunkV1)
|
||||
full = chunk if full is None else full + chunk
|
||||
assert isinstance(full, AIMessageChunkV1)
|
||||
code_interpreter_call = next(
|
||||
item for item in full.content if item["type"] == "code_interpreter_call"
|
||||
)
|
||||
code_interpreter_result = next(
|
||||
item for item in full.content if item["type"] == "code_interpreter_result"
|
||||
)
|
||||
assert code_interpreter_call
|
||||
assert code_interpreter_result
|
||||
assert tool_outputs
|
||||
|
||||
# Test we can pass back in
|
||||
@ -634,9 +762,59 @@ def test_mcp_builtin_zdr() -> None:
    _ = llm_with_tools.invoke([input_message, full, approval_message])

@pytest.mark.default_cassette("test_mcp_builtin_zdr.yaml.gz")
|
||||
@pytest.mark.vcr
|
||||
def test_mcp_builtin_zdr_v1() -> None:
|
||||
llm = ChatOpenAIV1(
|
||||
model="o4-mini", store=False, include=["reasoning.encrypted_content"]
|
||||
)
|
||||
|
||||
llm_with_tools = llm.bind_tools(
|
||||
[
|
||||
{
|
||||
"type": "mcp",
|
||||
"server_label": "deepwiki",
|
||||
"server_url": "https://mcp.deepwiki.com/mcp",
|
||||
"require_approval": {"always": {"tool_names": ["read_wiki_structure"]}},
|
||||
}
|
||||
]
|
||||
)
|
||||
input_message = {
|
||||
"role": "user",
|
||||
"content": (
|
||||
"What transport protocols does the 2025-03-26 version of the MCP spec "
|
||||
"support?"
|
||||
),
|
||||
}
|
||||
full: Optional[AIMessageChunkV1] = None
|
||||
for chunk in llm_with_tools.stream([input_message]):
|
||||
assert isinstance(chunk, AIMessageChunkV1)
|
||||
full = chunk if full is None else full + chunk
|
||||
|
||||
assert isinstance(full, AIMessageChunkV1)
|
||||
assert all(isinstance(block, dict) for block in full.content)
|
||||
|
||||
approval_message = HumanMessageV1(
|
||||
[
|
||||
{
|
||||
"type": "non_standard",
|
||||
"value": {
|
||||
"type": "mcp_approval_response",
|
||||
"approve": True,
|
||||
"approval_request_id": block["value"]["id"], # type: ignore[index]
|
||||
},
|
||||
}
|
||||
for block in full.content
|
||||
if block["type"] == "non_standard"
|
||||
and block["value"]["type"] == "mcp_approval_request" # type: ignore[index]
|
||||
]
|
||||
)
|
||||
_ = llm_with_tools.invoke([input_message, full, approval_message])
|
||||
|
||||
|
||||
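Pairing an `mcp_approval_request` with an approval response is plain data manipulation over the v1 content blocks; a minimal helper mirroring the list comprehension above (the sample block is illustrative):

```python
from typing import Any


def build_approval_blocks(content: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Create an mcp_approval_response block for every approval request."""
    return [
        {
            "type": "non_standard",
            "value": {
                "type": "mcp_approval_response",
                "approve": True,
                "approval_request_id": block["value"]["id"],
            },
        }
        for block in content
        if block["type"] == "non_standard"
        and block["value"]["type"] == "mcp_approval_request"
    ]


sample_content = [
    {"type": "non_standard", "value": {"type": "mcp_approval_request", "id": "req_123"}}
]
approvals = build_approval_blocks(sample_content)
assert approvals[0]["value"]["approval_request_id"] == "req_123"
```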
@pytest.mark.default_cassette("test_image_generation_streaming.yaml.gz")
|
||||
@pytest.mark.vcr
|
||||
@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
|
||||
@pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
|
||||
def test_image_generation_streaming(output_version: str) -> None:
|
||||
"""Test image generation streaming."""
|
||||
llm = ChatOpenAI(
|
||||
@ -710,9 +888,52 @@ def test_image_generation_streaming(output_version: str) -> None:
|
||||
assert set(standard_keys).issubset(tool_output.keys())
|
||||
|
||||
|
||||
@pytest.mark.default_cassette("test_image_generation_streaming.yaml.gz")
|
||||
@pytest.mark.vcr
|
||||
def test_image_generation_streaming_v1() -> None:
|
||||
"""Test image generation streaming."""
|
||||
llm = ChatOpenAIV1(model="gpt-4.1", use_responses_api=True)
|
||||
tool = {
|
||||
"type": "image_generation",
|
||||
"quality": "low",
|
||||
"output_format": "jpeg",
|
||||
"output_compression": 100,
|
||||
"size": "1024x1024",
|
||||
}
|
||||
|
||||
expected_keys = {
|
||||
# Standard
|
||||
"type",
|
||||
"base64",
|
||||
"mime_type",
|
||||
"id",
|
||||
"index",
|
||||
# OpenAI-specific
|
||||
"background",
|
||||
"output_format",
|
||||
"quality",
|
||||
"revised_prompt",
|
||||
"size",
|
||||
"status",
|
||||
}
|
||||
|
||||
full: Optional[AIMessageChunkV1] = None
|
||||
for chunk in llm.stream("Draw a random short word in green font.", tools=[tool]):
|
||||
assert isinstance(chunk, AIMessageChunkV1)
|
||||
full = chunk if full is None else full + chunk
|
||||
complete_ai_message = cast(AIMessageChunkV1, full)
|
||||
|
||||
tool_output = next(
|
||||
block
|
||||
for block in complete_ai_message.content
|
||||
if isinstance(block, dict) and block["type"] == "image"
|
||||
)
|
||||
assert set(expected_keys).issubset(tool_output.keys())
|
||||
|
||||
|
||||
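Each streamed image block carries the bytes as base64 plus a `mime_type`, so persisting the result only needs the standard library; a sketch over the block shape asserted above (the sample payload is illustrative):

```python
import base64
from typing import Any


def save_image_block(block: dict[str, Any], path: str) -> None:
    """Decode a v1 image content block and write it to disk."""
    assert block["type"] == "image"
    with open(path, "wb") as f:
        f.write(base64.b64decode(block["base64"]))


# A tiny illustrative payload; real blocks also carry id, size, quality, etc.
sample_block = {
    "type": "image",
    "mime_type": "image/jpeg",
    "base64": base64.b64encode(b"\xff\xd8\xff\xd9").decode(),
}
save_image_block(sample_block, "word.jpg")
```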
@pytest.mark.default_cassette("test_image_generation_multi_turn.yaml.gz")
|
||||
@pytest.mark.vcr
|
||||
@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
|
||||
@pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
|
||||
def test_image_generation_multi_turn(output_version: str) -> None:
|
||||
"""Test multi-turn editing of image generation by passing in history."""
|
||||
# Test multi-turn
|
||||
@ -735,7 +956,7 @@ def test_image_generation_multi_turn(output_version: str) -> None:
|
||||
]
|
||||
ai_message = llm_with_tools.invoke(chat_history)
|
||||
assert isinstance(ai_message, AIMessage)
|
||||
_check_response(ai_message)
|
||||
_check_response(ai_message, output_version)
|
||||
|
||||
expected_keys = {
|
||||
"id",
|
||||
@ -801,7 +1022,7 @@ def test_image_generation_multi_turn(output_version: str) -> None:
|
||||
|
||||
ai_message2 = llm_with_tools.invoke(chat_history)
|
||||
assert isinstance(ai_message2, AIMessage)
|
||||
_check_response(ai_message2)
|
||||
_check_response(ai_message2, output_version)
|
||||
|
||||
if output_version == "v0":
|
||||
tool_output = ai_message2.additional_kwargs["tool_outputs"][0]
|
||||
@ -821,3 +1042,76 @@ def test_image_generation_multi_turn(output_version: str) -> None:
|
||||
if isinstance(block, dict) and block["type"] == "image"
|
||||
)
|
||||
assert set(standard_keys).issubset(tool_output.keys())
|
||||
|
||||
|
||||
@pytest.mark.default_cassette("test_image_generation_multi_turn.yaml.gz")
|
||||
@pytest.mark.vcr
|
||||
def test_image_generation_multi_turn_v1() -> None:
|
||||
"""Test multi-turn editing of image generation by passing in history."""
|
||||
# Test multi-turn
|
||||
llm = ChatOpenAIV1(model="gpt-4.1", use_responses_api=True)
|
||||
# Test invocation
|
||||
tool = {
|
||||
"type": "image_generation",
|
||||
"quality": "low",
|
||||
"output_format": "jpeg",
|
||||
"output_compression": 100,
|
||||
"size": "1024x1024",
|
||||
}
|
||||
llm_with_tools = llm.bind_tools([tool])
|
||||
|
||||
chat_history: list[MessageLikeRepresentation] = [
|
||||
{"role": "user", "content": "Draw a random short word in green font."}
|
||||
]
|
||||
ai_message = llm_with_tools.invoke(chat_history)
|
||||
assert isinstance(ai_message, AIMessageV1)
|
||||
_check_response(ai_message, "v1")
|
||||
|
||||
expected_keys = {
|
||||
# Standard
|
||||
"type",
|
||||
"base64",
|
||||
"mime_type",
|
||||
"id",
|
||||
# OpenAI-specific
|
||||
"background",
|
||||
"output_format",
|
||||
"quality",
|
||||
"revised_prompt",
|
||||
"size",
|
||||
"status",
|
||||
}
|
||||
|
||||
standard_keys = {"type", "base64", "id", "status"}
|
||||
tool_output = next(
|
||||
block
|
||||
for block in ai_message.content
|
||||
if isinstance(block, dict) and block["type"] == "image"
|
||||
)
|
||||
assert set(standard_keys).issubset(tool_output.keys())
|
||||
|
||||
chat_history.extend(
|
||||
[
|
||||
# AI message with tool output
|
||||
ai_message,
|
||||
# New request
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"Now, change the font to blue. Keep the word and everything else "
|
||||
"the same."
|
||||
),
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
ai_message2 = llm_with_tools.invoke(chat_history)
|
||||
assert isinstance(ai_message2, AIMessageV1)
|
||||
_check_response(ai_message2, "v1")
|
||||
|
||||
tool_output = next(
|
||||
block
|
||||
for block in ai_message2.content
|
||||
if isinstance(block, dict) and block["type"] == "image"
|
||||
)
|
||||
assert set(expected_keys).issubset(tool_output.keys())
|
||||
|
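The multi-turn edit works because the assistant message, including its `image` block, is appended to the history unchanged before the follow-up request. A hedged sketch of the same flow against the stable client (API key required; model name and tool options are illustrative):

```python
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4.1", use_responses_api=True)  # illustrative model
llm_with_tools = llm.bind_tools([{"type": "image_generation", "quality": "low"}])

chat_history = [{"role": "user", "content": "Draw a random short word in green font."}]
ai_message = llm_with_tools.invoke(chat_history)

# Pass the assistant turn back unchanged, then ask for an edit.
chat_history.extend(
    [
        ai_message,
        {"role": "user", "content": "Now change the font to blue."},
    ]
)
ai_message2 = llm_with_tools.invoke(chat_history)
```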