Mirror of https://github.com/hwchase17/langchain.git (synced 2025-08-03 02:06:33 +00:00)

feat(openai): v1 message format support (#32296)

This commit is contained in:
parent 7166adce1f
commit c15e55b33c
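As a rough, illustrative sketch of the v1 message surface this commit works with (not part of the diff; values are made up, and behavior is assumed from the `AIMessage.__init__` and property hunks further down):

from langchain_core.messages.v1 import AIMessage as AIMessageV1

# Illustrative only: tool calls passed on init are appended to the content
# block list (see the AIMessage.__init__ hunk below).
msg = AIMessageV1(
    content=[{"type": "text", "text": "Checking the weather."}],
    tool_calls=[
        {
            "type": "tool_call",
            "name": "get_weather",
            "args": {"city": "Paris"},
            "id": "call_1",
        }
    ],
)
print(msg.text)        # text drawn from the "text" content blocks
print(msg.tool_calls)  # "tool_call" blocks drawn from content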
@@ -2,8 +2,7 @@
 
 from __future__ import annotations
 
-import warnings
-from abc import ABC, abstractmethod
+from abc import ABC
 from collections.abc import Mapping, Sequence
 from functools import cache
 from typing import (
@@ -26,7 +25,6 @@ from langchain_core.messages import (
     AnyMessage,
     BaseMessage,
     MessageLikeRepresentation,
-    get_buffer_string,
 )
 from langchain_core.messages.v1 import AIMessage as AIMessageV1
 from langchain_core.prompt_values import PromptValue
@@ -166,7 +164,6 @@ class BaseLanguageModel(
             list[AnyMessage],
         ]
 
-    @abstractmethod
     def generate_prompt(
         self,
         prompts: list[PromptValue],
@@ -201,7 +198,6 @@ class BaseLanguageModel(
             prompt and additional model provider-specific output.
         """
 
-    @abstractmethod
     async def agenerate_prompt(
         self,
         prompts: list[PromptValue],
@@ -245,7 +241,6 @@ class BaseLanguageModel(
         raise NotImplementedError
 
     @deprecated("0.1.7", alternative="invoke", removal="1.0")
-    @abstractmethod
     def predict(
         self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
     ) -> str:
@@ -266,7 +261,6 @@ class BaseLanguageModel(
         """
 
     @deprecated("0.1.7", alternative="invoke", removal="1.0")
-    @abstractmethod
     def predict_messages(
         self,
         messages: list[BaseMessage],
@@ -291,7 +285,6 @@ class BaseLanguageModel(
         """
 
     @deprecated("0.1.7", alternative="ainvoke", removal="1.0")
-    @abstractmethod
     async def apredict(
         self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
     ) -> str:
@@ -312,7 +305,6 @@ class BaseLanguageModel(
         """
 
     @deprecated("0.1.7", alternative="ainvoke", removal="1.0")
-    @abstractmethod
     async def apredict_messages(
         self,
         messages: list[BaseMessage],
@@ -368,33 +360,6 @@ class BaseLanguageModel(
         """
         return len(self.get_token_ids(text))
 
-    def get_num_tokens_from_messages(
-        self,
-        messages: list[BaseMessage],
-        tools: Optional[Sequence] = None,
-    ) -> int:
-        """Get the number of tokens in the messages.
-
-        Useful for checking if an input fits in a model's context window.
-
-        **Note**: the base implementation of get_num_tokens_from_messages ignores
-            tool schemas.
-
-        Args:
-            messages: The message inputs to tokenize.
-            tools: If provided, sequence of dict, BaseModel, function, or BaseTools
-                to be converted to tool schemas.
-
-        Returns:
-            The sum of the number of tokens across the messages.
-        """
-        if tools is not None:
-            warnings.warn(
-                "Counting tokens in tool schemas is not yet supported. Ignoring tools.",
-                stacklevel=2,
-            )
-        return sum(self.get_num_tokens(get_buffer_string([m])) for m in messages)
-
     @classmethod
     def _all_required_field_names(cls) -> set:
         """DEPRECATED: Kept for backwards compatibility.
@@ -55,12 +55,11 @@ from langchain_core.messages import (
     HumanMessage,
     convert_to_messages,
     convert_to_openai_image_block,
+    get_buffer_string,
     is_data_content_block,
     message_chunk_to_message,
 )
-from langchain_core.messages import content_blocks as types
 from langchain_core.messages.ai import _LC_ID_PREFIX
-from langchain_core.messages.v1 import AIMessage as AIMessageV1
 from langchain_core.outputs import (
     ChatGeneration,
     ChatGenerationChunk,
@@ -222,23 +221,6 @@ def _format_ls_structured_output(ls_structured_output_format: Optional[dict]) ->
     return ls_structured_output_format_dict
 
 
-def _convert_to_v1(message: AIMessage) -> AIMessageV1:
-    """Best-effort conversion of a V0 AIMessage to V1."""
-    if isinstance(message.content, str):
-        content: list[types.ContentBlock] = []
-        if message.content:
-            content = [{"type": "text", "text": message.content}]
-
-    for tool_call in message.tool_calls:
-        content.append(tool_call)
-
-    return AIMessageV1(
-        content=content,
-        usage_metadata=message.usage_metadata,
-        response_metadata=message.response_metadata,
-    )
-
-
 class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
     """Base class for chat models.
 
@@ -1370,6 +1352,33 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         starter_dict["_type"] = self._llm_type
         return starter_dict
 
+    def get_num_tokens_from_messages(
+        self,
+        messages: list[BaseMessage],
+        tools: Optional[Sequence] = None,
+    ) -> int:
+        """Get the number of tokens in the messages.
+
+        Useful for checking if an input fits in a model's context window.
+
+        **Note**: the base implementation of get_num_tokens_from_messages ignores
+            tool schemas.
+
+        Args:
+            messages: The message inputs to tokenize.
+            tools: If provided, sequence of dict, BaseModel, function, or BaseTools
+                to be converted to tool schemas.
+
+        Returns:
+            The sum of the number of tokens across the messages.
+        """
+        if tools is not None:
+            warnings.warn(
+                "Counting tokens in tool schemas is not yet supported. Ignoring tools.",
+                stacklevel=2,
+            )
+        return sum(self.get_num_tokens(get_buffer_string([m])) for m in messages)
+
     def bind_tools(
         self,
         tools: Sequence[
@@ -4,6 +4,7 @@ from __future__ import annotations
 
 import copy
 import typing
+import warnings
 from abc import ABC, abstractmethod
 from collections.abc import AsyncIterator, Iterator, Sequence
 from operator import itemgetter
@@ -38,11 +39,14 @@ from langchain_core.language_models.base import (
 )
 from langchain_core.messages import (
     AIMessage,
-    BaseMessage,
     convert_to_openai_image_block,
+    get_buffer_string,
     is_data_content_block,
 )
-from langchain_core.messages.utils import convert_to_messages_v1
+from langchain_core.messages.utils import (
+    _convert_from_v1_message,
+    convert_to_messages_v1,
+)
 from langchain_core.messages.v1 import AIMessage as AIMessageV1
 from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
 from langchain_core.messages.v1 import HumanMessage as HumanMessageV1
@@ -735,7 +739,7 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
         *,
         tool_choice: Optional[Union[str]] = None,
         **kwargs: Any,
-    ) -> Runnable[LanguageModelInput, BaseMessage]:
+    ) -> Runnable[LanguageModelInput, AIMessageV1]:
         """Bind tools to the model.
 
         Args:
@@ -899,6 +903,34 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
             return RunnableMap(raw=llm) | parser_with_fallback
         return llm | output_parser
 
+    def get_num_tokens_from_messages(
+        self,
+        messages: list[MessageV1],
+        tools: Optional[Sequence] = None,
+    ) -> int:
+        """Get the number of tokens in the messages.
+
+        Useful for checking if an input fits in a model's context window.
+
+        **Note**: the base implementation of get_num_tokens_from_messages ignores
+            tool schemas.
+
+        Args:
+            messages: The message inputs to tokenize.
+            tools: If provided, sequence of dict, BaseModel, function, or BaseTools
+                to be converted to tool schemas.
+
+        Returns:
+            The sum of the number of tokens across the messages.
+        """
+        messages_v0 = [_convert_from_v1_message(message) for message in messages]
+        if tools is not None:
+            warnings.warn(
+                "Counting tokens in tool schemas is not yet supported. Ignoring tools.",
+                stacklevel=2,
+            )
+        return sum(self.get_num_tokens(get_buffer_string([m])) for m in messages_v0)
+
 
 def _gen_info_and_msg_metadata(
     generation: Union[ChatGeneration, ChatGenerationChunk],
@@ -706,6 +706,7 @@ ToolContentBlock = Union[
 ContentBlock = Union[
     TextContentBlock,
     ToolCall,
+    InvalidToolCall,
     ReasoningContentBlock,
     NonStandardContentBlock,
     DataContentBlock,
@@ -384,38 +384,37 @@ def _convert_from_v1_message(message: MessageV1) -> BaseMessage:
     Returns:
         BaseMessage: Converted message instance.
     """
-    # type ignores here are because AIMessageV1.content is a list of dicts.
-    # AIMessageV0.content expects str or list[str | dict].
+    content = cast("Union[str, list[str | dict]]", message.content)
     if isinstance(message, AIMessageV1):
         return AIMessage(
-            content=message.content,  # type: ignore[arg-type]
+            content=content,
             id=message.id,
             name=message.name,
             tool_calls=message.tool_calls,
-            response_metadata=message.response_metadata,
+            response_metadata=cast("dict", message.response_metadata),
         )
     if isinstance(message, AIMessageChunkV1):
         return AIMessageChunk(
-            content=message.content,  # type: ignore[arg-type]
+            content=content,
             id=message.id,
             name=message.name,
             tool_call_chunks=message.tool_call_chunks,
-            response_metadata=message.response_metadata,
+            response_metadata=cast("dict", message.response_metadata),
        )
     if isinstance(message, HumanMessageV1):
         return HumanMessage(
-            content=message.content,  # type: ignore[arg-type]
+            content=content,
             id=message.id,
             name=message.name,
         )
     if isinstance(message, SystemMessageV1):
         return SystemMessage(
-            content=message.content,  # type: ignore[arg-type]
+            content=content,
             id=message.id,
         )
     if isinstance(message, ToolMessageV1):
         return ToolMessage(
-            content=message.content,  # type: ignore[arg-type]
+            content=content,
             id=message.id,
         )
     message = f"Unsupported message type: {type(message)}"
@@ -501,7 +500,10 @@ def _convert_to_message_v1(message: MessageLikeRepresentation) -> MessageV1:
         ValueError: if the message dict does not contain the required keys.
     """
     if isinstance(message, MessageV1Types):
-        message_ = message
+        if isinstance(message, AIMessageChunkV1):
+            message_ = message.to_message()
+        else:
+            message_ = message
     elif isinstance(message, str):
         message_ = _create_message_from_message_type_v1("human", message)
     elif isinstance(message, Sequence) and len(message) == 2:
@@ -5,6 +5,8 @@ import uuid
 from dataclasses import dataclass, field
 from typing import Any, Literal, Optional, TypedDict, Union, cast, get_args
 
+from pydantic import BaseModel
+
 import langchain_core.messages.content_blocks as types
 from langchain_core.messages.ai import _LC_ID_PREFIX, UsageMetadata, add_usage
 from langchain_core.messages.base import merge_content
@@ -32,20 +34,20 @@ def _ensure_id(id_val: Optional[str]) -> str:
     return id_val or str(uuid.uuid4())
 
 
-class Provider(TypedDict):
-    """Information about the provider that generated the message.
+class ResponseMetadata(TypedDict, total=False):
+    """Metadata about the response from the AI provider.
 
-    Contains metadata about the AI provider and model used to generate content.
+    Contains additional information returned by the provider, such as
+    response headers, service tiers, log probabilities, system fingerprints, etc.
 
-    Attributes:
-        name: Name and version of the provider that created the content block.
-        model_name: Name of the model that generated the content block.
+    Extra keys are permitted from what is typed here.
     """
 
-    name: str
-    """Name and version of the provider that created the content block."""
+    model_provider: str
+    """Name and version of the provider that created the message (e.g., openai)."""
 
     model_name: str
-    """Name of the model that generated the content block."""
+    """Name of the model that generated the message."""
 
 
 @dataclass
@@ -91,21 +93,29 @@ class AIMessage:
     usage_metadata: Optional[UsageMetadata] = None
     """If provided, usage metadata for a message, such as token counts."""
 
-    response_metadata: dict = field(default_factory=dict)
+    response_metadata: ResponseMetadata = field(
+        default_factory=lambda: ResponseMetadata()
+    )
     """Metadata about the response.
 
     This field should include non-standard data returned by the provider, such as
     response headers, service tiers, or log probabilities.
     """
 
+    parsed: Optional[Union[dict[str, Any], BaseModel]] = None
+    """Auto-parsed message contents, if applicable."""
+
     def __init__(
         self,
         content: Union[str, list[types.ContentBlock]],
         id: Optional[str] = None,
         name: Optional[str] = None,
         lc_version: str = "v1",
-        response_metadata: Optional[dict] = None,
+        response_metadata: Optional[ResponseMetadata] = None,
         usage_metadata: Optional[UsageMetadata] = None,
+        tool_calls: Optional[list[types.ToolCall]] = None,
+        invalid_tool_calls: Optional[list[types.InvalidToolCall]] = None,
+        parsed: Optional[Union[dict[str, Any], BaseModel]] = None,
     ):
         """Initialize an AI message.
 
@@ -116,6 +126,11 @@ class AIMessage:
             lc_version: Encoding version for the message.
             response_metadata: Optional metadata about the response.
             usage_metadata: Optional metadata about token usage.
+            tool_calls: Optional list of tool calls made by the AI. Tool calls should
+                generally be included in message content. If passed on init, they will
+                be added to the content list.
+            invalid_tool_calls: Optional list of tool calls that failed validation.
+            parsed: Optional auto-parsed message contents, if applicable.
         """
         if isinstance(content, str):
             self.content = [{"type": "text", "text": content}]
@@ -126,13 +141,27 @@ class AIMessage:
         self.name = name
         self.lc_version = lc_version
         self.usage_metadata = usage_metadata
+        self.parsed = parsed
         if response_metadata is None:
             self.response_metadata = {}
         else:
             self.response_metadata = response_metadata
 
-        self._tool_calls: list[types.ToolCall] = []
-        self._invalid_tool_calls: list[types.InvalidToolCall] = []
+        # Add tool calls to content if provided on init
+        if tool_calls:
+            content_tool_calls = {
+                block["id"]
+                for block in self.content
+                if block["type"] == "tool_call" and "id" in block
+            }
+            for tool_call in tool_calls:
+                if "id" in tool_call and tool_call["id"] in content_tool_calls:
+                    continue
+                self.content.append(tool_call)
+        self._tool_calls = [
+            block for block in self.content if block["type"] == "tool_call"
+        ]
+        self.invalid_tool_calls = invalid_tool_calls or []
 
     @property
     def text(self) -> Optional[str]:
@@ -150,7 +179,7 @@ class AIMessage:
         tool_calls = [block for block in self.content if block["type"] == "tool_call"]
         if tool_calls:
             self._tool_calls = tool_calls
-        return self._tool_calls
+        return [block for block in self.content if block["type"] == "tool_call"]
 
     @tool_calls.setter
     def tool_calls(self, value: list[types.ToolCall]) -> None:
@@ -202,13 +231,16 @@ class AIMessageChunk:
     These data represent incremental usage statistics, as opposed to a running total.
     """
 
-    response_metadata: dict = field(init=False)
+    response_metadata: ResponseMetadata = field(init=False)
     """Metadata about the response chunk.
 
     This field should include non-standard data returned by the provider, such as
     response headers, service tiers, or log probabilities.
     """
 
+    parsed: Optional[Union[dict[str, Any], BaseModel]] = None
+    """Auto-parsed message contents, if applicable."""
+
     tool_call_chunks: list[types.ToolCallChunk] = field(init=False)
 
     def __init__(
@@ -217,9 +249,10 @@ class AIMessageChunk:
         id: Optional[str] = None,
         name: Optional[str] = None,
         lc_version: str = "v1",
-        response_metadata: Optional[dict] = None,
+        response_metadata: Optional[ResponseMetadata] = None,
         usage_metadata: Optional[UsageMetadata] = None,
         tool_call_chunks: Optional[list[types.ToolCallChunk]] = None,
+        parsed: Optional[Union[dict[str, Any], BaseModel]] = None,
     ):
         """Initialize an AI message.
 
@@ -231,6 +264,7 @@ class AIMessageChunk:
             response_metadata: Optional metadata about the response.
             usage_metadata: Optional metadata about token usage.
             tool_call_chunks: Optional list of partial tool call data.
+            parsed: Optional auto-parsed message contents, if applicable.
         """
         if isinstance(content, str):
             self.content = [{"type": "text", "text": content, "index": 0}]
@@ -241,6 +275,7 @@ class AIMessageChunk:
         self.name = name
         self.lc_version = lc_version
         self.usage_metadata = usage_metadata
+        self.parsed = parsed
         if response_metadata is None:
             self.response_metadata = {}
         else:
@@ -251,7 +286,7 @@ class AIMessageChunk:
             self.tool_call_chunks = tool_call_chunks
 
         self._tool_calls: list[types.ToolCall] = []
-        self._invalid_tool_calls: list[types.InvalidToolCall] = []
+        self.invalid_tool_calls: list[types.InvalidToolCall] = []
         self._init_tool_calls()
 
     def _init_tool_calls(self) -> None:
@@ -264,7 +299,7 @@ class AIMessageChunk:
             ValueError: If the tool call chunks are malformed.
         """
         self._tool_calls = []
-        self._invalid_tool_calls = []
+        self.invalid_tool_calls = []
         if not self.tool_call_chunks:
             if self._tool_calls:
                 self.tool_call_chunks = [
@@ -276,14 +311,14 @@ class AIMessageChunk:
                     )
                     for tc in self._tool_calls
                 ]
-            if self._invalid_tool_calls:
+            if self.invalid_tool_calls:
                 tool_call_chunks = self.tool_call_chunks
                 tool_call_chunks.extend(
                     [
                         create_tool_call_chunk(
                             name=tc["name"], args=tc["args"], id=tc["id"], index=None
                         )
-                        for tc in self._invalid_tool_calls
+                        for tc in self.invalid_tool_calls
                     ]
                 )
                 self.tool_call_chunks = tool_call_chunks
@@ -294,9 +329,9 @@ class AIMessageChunk:
         def add_chunk_to_invalid_tool_calls(chunk: ToolCallChunk) -> None:
             invalid_tool_calls.append(
                 create_invalid_tool_call(
-                    name=chunk["name"],
-                    args=chunk["args"],
-                    id=chunk["id"],
+                    name=chunk.get("name", ""),
+                    args=chunk.get("args", ""),
+                    id=chunk.get("id", ""),
                     error=None,
                 )
             )
@@ -307,9 +342,9 @@ class AIMessageChunk:
                 if isinstance(args_, dict):
                     tool_calls.append(
                         create_tool_call(
-                            name=chunk["name"] or "",
+                            name=chunk.get("name", ""),
                             args=args_,
-                            id=chunk["id"],
+                            id=chunk.get("id", ""),
                         )
                     )
                 else:
@@ -317,7 +352,7 @@ class AIMessageChunk:
             except Exception:
                 add_chunk_to_invalid_tool_calls(chunk)
         self._tool_calls = tool_calls
-        self._invalid_tool_calls = invalid_tool_calls
+        self.invalid_tool_calls = invalid_tool_calls
 
     @property
     def text(self) -> Optional[str]:
@@ -361,6 +396,20 @@ class AIMessageChunk:
         error_msg = "Can only add AIMessageChunk or sequence of AIMessageChunk."
         raise NotImplementedError(error_msg)
 
+    def to_message(self) -> "AIMessage":
+        """Convert this AIMessageChunk to an AIMessage."""
+        return AIMessage(
+            content=self.content,
+            id=self.id,
+            name=self.name,
+            lc_version=self.lc_version,
+            response_metadata=self.response_metadata,
+            usage_metadata=self.usage_metadata,
+            tool_calls=self.tool_calls,
+            invalid_tool_calls=self.invalid_tool_calls,
+            parsed=self.parsed,
+        )
+
 
 def add_ai_message_chunks(
     left: AIMessageChunk, *others: AIMessageChunk
@@ -373,7 +422,8 @@ def add_ai_message_chunks(
         *(cast("list[str | dict[Any, Any]]", o.content) for o in others),
     )
     response_metadata = merge_dicts(
-        left.response_metadata, *(o.response_metadata for o in others)
+        cast("dict", left.response_metadata),
+        *(cast("dict", o.response_metadata) for o in others),
     )
 
     # Merge tool call chunks
@@ -400,6 +450,15 @@ def add_ai_message_chunks(
     else:
         usage_metadata = None
 
+    # Parsed
+    # 'parsed' always represents an aggregation not an incremental value, so the last
+    # non-null value is kept.
+    parsed = None
+    for m in reversed([left, *others]):
+        if m.parsed is not None:
+            parsed = m.parsed
+            break
+
     chunk_id = None
     candidates = [left.id] + [o.id for o in others]
     # first pass: pick the first non-run-* id
@@ -417,8 +476,9 @@ def add_ai_message_chunks(
     return left.__class__(
         content=cast("list[types.ContentBlock]", content),
         tool_call_chunks=tool_call_chunks,
-        response_metadata=response_metadata,
+        response_metadata=cast("ResponseMetadata", response_metadata),
         usage_metadata=usage_metadata,
+        parsed=parsed,
         id=chunk_id,
     )
 
@@ -455,19 +515,25 @@ class HumanMessage:
     """
 
     def __init__(
-        self, content: Union[str, list[types.ContentBlock]], id: Optional[str] = None
+        self,
+        content: Union[str, list[types.ContentBlock]],
+        *,
+        id: Optional[str] = None,
+        name: Optional[str] = None,
     ):
         """Initialize a human message.
 
         Args:
             content: Message content as string or list of content blocks.
             id: Optional unique identifier for the message.
+            name: Optional human-readable name for the message.
         """
         self.id = _ensure_id(id)
         if isinstance(content, str):
             self.content = [{"type": "text", "text": content}]
         else:
             self.content = content
+        self.name = name
 
     def text(self) -> str:
         """Extract all text content from the message.
@@ -497,20 +563,47 @@ class SystemMessage:
     content: list[types.ContentBlock]
     type: Literal["system"] = "system"
 
+    name: Optional[str] = None
+    """An optional name for the message.
+
+    This can be used to provide a human-readable name for the message.
+
+    Usage of this field is optional, and whether it's used or not is up to the
+    model implementation.
+    """
+
+    custom_role: Optional[str] = None
+    """If provided, a custom role for the system message.
+
+    Example: ``"developer"``.
+
+    Integration packages may use this field to assign the system message role if it
+    contains a recognized value.
+    """
+
     def __init__(
-        self, content: Union[str, list[types.ContentBlock]], *, id: Optional[str] = None
+        self,
+        content: Union[str, list[types.ContentBlock]],
+        *,
+        id: Optional[str] = None,
+        custom_role: Optional[str] = None,
+        name: Optional[str] = None,
     ):
-        """Initialize a system message.
+        """Initialize a human message.
 
         Args:
-            content: System instructions as string or list of content blocks.
+            content: Message content as string or list of content blocks.
             id: Optional unique identifier for the message.
+            custom_role: If provided, a custom role for the system message.
+            name: Optional human-readable name for the message.
         """
         self.id = _ensure_id(id)
         if isinstance(content, str):
             self.content = [{"type": "text", "text": content}]
         else:
             self.content = content
+        self.custom_role = custom_role
+        self.name = name
 
     def text(self) -> str:
         """Extract all text content from the system message."""
@@ -537,11 +630,51 @@ class ToolMessage:
 
     id: str
     tool_call_id: str
-    content: list[dict[str, Any]]
+    content: list[types.ContentBlock]
     artifact: Optional[Any] = None  # App-side payload not for the model
+
+    name: Optional[str] = None
+    """An optional name for the message.
+
+    This can be used to provide a human-readable name for the message.
+
+    Usage of this field is optional, and whether it's used or not is up to the
+    model implementation.
+    """
+
     status: Literal["success", "error"] = "success"
     type: Literal["tool"] = "tool"
 
+    def __init__(
+        self,
+        content: Union[str, list[types.ContentBlock]],
+        tool_call_id: str,
+        *,
+        id: Optional[str] = None,
+        name: Optional[str] = None,
+        artifact: Optional[Any] = None,
+        status: Literal["success", "error"] = "success",
+    ):
+        """Initialize a human message.
+
+        Args:
+            content: Message content as string or list of content blocks.
+            tool_call_id: ID of the tool call this message responds to.
+            id: Optional unique identifier for the message.
+            name: Optional human-readable name for the message.
+            artifact: Optional app-side payload not intended for the model.
+            status: Execution status ("success" or "error").
+        """
+        self.id = _ensure_id(id)
+        self.tool_call_id = tool_call_id
+        if isinstance(content, str):
+            self.content = [{"type": "text", "text": content}]
+        else:
+            self.content = content
+        self.name = name
+        self.artifact = artifact
+        self.status = status
+
     @property
     def text(self) -> str:
         """Extract all text content from the tool message."""
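A small usage sketch of the v1 ToolMessage constructor added above (illustrative values; string content is wrapped into a text block per the `__init__` shown in the hunk):

from langchain_core.messages.v1 import ToolMessage as ToolMessageV1

# Illustrative: a tool result answering tool call "call_1".
result = ToolMessageV1(
    "72F and sunny",
    tool_call_id="call_1",
    status="success",
)
print(result.content)  # [{"type": "text", "text": "72F and sunny"}]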
@@ -9,7 +9,7 @@ from typing import Annotated, Any, Optional
 from pydantic import SkipValidation, ValidationError
 
 from langchain_core.exceptions import OutputParserException
-from langchain_core.messages import AIMessage, InvalidToolCall
+from langchain_core.messages import AIMessage, InvalidToolCall, ToolCall
 from langchain_core.messages.tool import invalid_tool_call
 from langchain_core.messages.tool import tool_call as create_tool_call
 from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
@@ -26,7 +26,7 @@ def parse_tool_call(
     partial: bool = False,
     strict: bool = False,
     return_id: bool = True,
-) -> Optional[dict[str, Any]]:
+) -> Optional[ToolCall]:
     """Parse a single tool call.
 
     Args:
@@ -8,17 +8,65 @@ from __future__ import annotations
 
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
-from typing import Literal, cast
+from typing import Literal, Union, cast
 
-from typing_extensions import TypedDict
+from typing_extensions import TypedDict, overload
 
 from langchain_core.load.serializable import Serializable
 from langchain_core.messages import (
+    AIMessage,
     AnyMessage,
     BaseMessage,
     HumanMessage,
+    SystemMessage,
+    ToolMessage,
     get_buffer_string,
 )
+from langchain_core.messages import content_blocks as types
+from langchain_core.messages.v1 import AIMessage as AIMessageV1
+from langchain_core.messages.v1 import HumanMessage as HumanMessageV1
+from langchain_core.messages.v1 import MessageV1, ResponseMetadata
+from langchain_core.messages.v1 import SystemMessage as SystemMessageV1
+from langchain_core.messages.v1 import ToolMessage as ToolMessageV1
+
+
+def _convert_to_v1(message: BaseMessage) -> MessageV1:
+    """Best-effort conversion of a V0 AIMessage to V1."""
+    if isinstance(message.content, str):
+        content: list[types.ContentBlock] = []
+        if message.content:
+            content = [{"type": "text", "text": message.content}]
+    else:
+        content = []
+        for block in message.content:
+            if isinstance(block, str):
+                content.append({"type": "text", "text": block})
+            elif isinstance(block, dict):
+                content.append(cast("types.ContentBlock", block))
+            else:
+                pass
+
+    if isinstance(message, HumanMessage):
+        return HumanMessageV1(content=content)
+    if isinstance(message, AIMessage):
+        for tool_call in message.tool_calls:
+            content.append(tool_call)
+        return AIMessageV1(
+            content=content,
+            usage_metadata=message.usage_metadata,
+            response_metadata=cast("ResponseMetadata", message.response_metadata),
+            tool_calls=message.tool_calls,
+        )
+    if isinstance(message, SystemMessage):
+        return SystemMessageV1(content=content)
+    if isinstance(message, ToolMessage):
+        return ToolMessageV1(
+            tool_call_id=message.tool_call_id,
+            content=content,
+            artifact=message.artifact,
+        )
+    error_message = f"Unsupported message type: {type(message)}"
+    raise TypeError(error_message)
+
+
 class PromptValue(Serializable, ABC):
@@ -46,8 +94,18 @@ class PromptValue(Serializable, ABC):
     def to_string(self) -> str:
         """Return prompt value as string."""
 
+    @overload
+    def to_messages(
+        self, output_version: Literal["v0"] = "v0"
+    ) -> list[BaseMessage]: ...
+
+    @overload
+    def to_messages(self, output_version: Literal["v1"]) -> list[MessageV1]: ...
+
     @abstractmethod
-    def to_messages(self) -> list[BaseMessage]:
+    def to_messages(
+        self, output_version: Literal["v0", "v1"] = "v0"
+    ) -> Union[Sequence[BaseMessage], Sequence[MessageV1]]:
         """Return prompt as a list of Messages."""
 
 
@@ -71,8 +129,20 @@ class StringPromptValue(PromptValue):
         """Return prompt as string."""
         return self.text
 
-    def to_messages(self) -> list[BaseMessage]:
+    @overload
+    def to_messages(
+        self, output_version: Literal["v0"] = "v0"
+    ) -> list[BaseMessage]: ...
+
+    @overload
+    def to_messages(self, output_version: Literal["v1"]) -> list[MessageV1]: ...
+
+    def to_messages(
+        self, output_version: Literal["v0", "v1"] = "v0"
+    ) -> Union[Sequence[BaseMessage], Sequence[MessageV1]]:
         """Return prompt as messages."""
+        if output_version == "v1":
+            return [HumanMessageV1(content=self.text)]
         return [HumanMessage(content=self.text)]
 
 
@@ -89,8 +159,24 @@ class ChatPromptValue(PromptValue):
         """Return prompt as string."""
         return get_buffer_string(self.messages)
 
-    def to_messages(self) -> list[BaseMessage]:
-        """Return prompt as a list of messages."""
+    @overload
+    def to_messages(
+        self, output_version: Literal["v0"] = "v0"
+    ) -> list[BaseMessage]: ...
+
+    @overload
+    def to_messages(self, output_version: Literal["v1"]) -> list[MessageV1]: ...
+
+    def to_messages(
+        self, output_version: Literal["v0", "v1"] = "v0"
+    ) -> Union[Sequence[BaseMessage], Sequence[MessageV1]]:
+        """Return prompt as a list of messages.
+
+        Args:
+            output_version: The output version, either "v0" (default) or "v1".
+        """
+        if output_version == "v1":
+            return [_convert_to_v1(m) for m in self.messages]
         return list(self.messages)
 
     @classmethod
@@ -125,8 +211,26 @@ class ImagePromptValue(PromptValue):
         """Return prompt (image URL) as string."""
         return self.image_url["url"]
 
-    def to_messages(self) -> list[BaseMessage]:
+    @overload
+    def to_messages(
+        self, output_version: Literal["v0"] = "v0"
+    ) -> list[BaseMessage]: ...
+
+    @overload
+    def to_messages(self, output_version: Literal["v1"]) -> list[MessageV1]: ...
+
+    def to_messages(
+        self, output_version: Literal["v0", "v1"] = "v0"
+    ) -> Union[Sequence[BaseMessage], Sequence[MessageV1]]:
        """Return prompt (image URL) as messages."""
+        if output_version == "v1":
+            block: types.ImageContentBlock = {
+                "type": "image",
+                "url": self.image_url["url"],
+            }
+            if "detail" in self.image_url:
+                block["detail"] = self.image_url["detail"]
+            return [HumanMessageV1(content=[block])]
         return [HumanMessage(content=[cast("dict", self.image_url)])]
 
 
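An illustrative sketch of the output_version switch added to the prompt values above (assuming the overloads behave as declared in the diff):

from langchain_core.prompt_values import StringPromptValue

pv = StringPromptValue(text="Hello!")
v0_msgs = pv.to_messages()                     # defaults to "v0": list of BaseMessage
v1_msgs = pv.to_messages(output_version="v1")  # list of v1 messages (HumanMessage v1 here)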
@@ -67,6 +67,7 @@ langchain-text-splitters = { path = "../text-splitters" }
 strict = "True"
 strict_bytes = "True"
 enable_error_code = "deprecated"
+disable_error_code = ["typeddict-unknown-key"]
 
 # TODO: activate for 'strict' checking
 disallow_any_generics = "False"
@@ -34,6 +34,7 @@ from langchain_core.messages.content_blocks import KNOWN_BLOCK_TYPES
 from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
 from langchain_core.messages.tool import tool_call as create_tool_call
 from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
+from langchain_core.messages.v1 import AIMessage as AIMessageV1
 from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
 from langchain_core.utils._merge import merge_lists
 
@@ -197,7 +198,7 @@ def test_message_chunks() -> None:
     assert (meaningful_id + default_id).id == "msg_def456"
 
 
-def test_message_chunks_v2() -> None:
+def test_message_chunks_v1() -> None:
     left = AIMessageChunkV1("foo ", id="abc")
     right = AIMessageChunkV1("bar")
     expected = AIMessageChunkV1("foo bar", id="abc")
@@ -230,7 +231,19 @@ def test_message_chunks_v2() -> None:
             )
         ],
     )
-    assert one + two + three == expected
+    result = one + two + three
+    assert result == expected
+
+    assert result.to_message() == AIMessageV1(
+        content=[
+            {
+                "name": "tool1",
+                "args": {"arg1": "value}"},
+                "id": "1",
+                "type": "tool_call",
+            }
+        ]
+    )
 
     assert (
         AIMessageChunkV1(
@@ -1326,6 +1339,7 @@ def test_known_block_types() -> None:
         "text",
         "text-plain",
         "tool_call",
+        "invalid_tool_call",
         "reasoning",
         "non_standard",
         "image",
@@ -1,10 +1,11 @@
-from langchain_openai.chat_models import AzureChatOpenAI, ChatOpenAI
+from langchain_openai.chat_models import AzureChatOpenAI, ChatOpenAI, ChatOpenAIV1
 from langchain_openai.embeddings import AzureOpenAIEmbeddings, OpenAIEmbeddings
 from langchain_openai.llms import AzureOpenAI, OpenAI
 
 __all__ = [
     "OpenAI",
     "ChatOpenAI",
+    "ChatOpenAIV1",
     "OpenAIEmbeddings",
     "AzureOpenAI",
     "AzureChatOpenAI",
@@ -1,4 +1,5 @@
 from langchain_openai.chat_models.azure import AzureChatOpenAI
 from langchain_openai.chat_models.base import ChatOpenAI
+from langchain_openai.chat_models.base_v1 import ChatOpenAI as ChatOpenAIV1
 
-__all__ = ["ChatOpenAI", "AzureChatOpenAI"]
+__all__ = ["ChatOpenAI", "AzureChatOpenAI", "ChatOpenAIV1"]
@@ -66,11 +66,14 @@ For backwards compatibility, this module provides functions to convert between t
 formats. The functions are used internally by ChatOpenAI.
 """  # noqa: E501
 
+import copy
 import json
 from collections.abc import Iterable, Iterator
-from typing import Any, Literal, Union, cast
+from typing import Any, Literal, Optional, Union, cast
 
 from langchain_core.messages import AIMessage, AIMessageChunk, is_data_content_block
+from langchain_core.messages import content_blocks as types
+from langchain_core.messages.v1 import AIMessage as AIMessageV1
 
 _FUNCTION_CALL_IDS_MAP_KEY = "__openai_function_call_ids__"
 
@@ -289,25 +292,21 @@ def _convert_to_v1_from_chat_completions_chunk(chunk: AIMessageChunk) -> AIMessa
     return cast(AIMessageChunk, result)
 
 
-def _convert_from_v1_to_chat_completions(message: AIMessage) -> AIMessage:
+def _convert_from_v1_to_chat_completions(message: AIMessageV1) -> AIMessageV1:
     """Convert a v1 message to the Chat Completions format."""
-    if isinstance(message.content, list):
-        new_content: list = []
-        for block in message.content:
-            if isinstance(block, dict):
-                block_type = block.get("type")
-                if block_type == "text":
-                    # Strip annotations
-                    new_content.append({"type": "text", "text": block["text"]})
-                elif block_type in ("reasoning", "tool_call"):
-                    pass
-                else:
-                    new_content.append(block)
-            else:
-                new_content.append(block)
-        return message.model_copy(update={"content": new_content})
+    new_content: list[types.ContentBlock] = []
+    for block in message.content:
+        if block["type"] == "text":
+            # Strip annotations
+            new_content.append({"type": "text", "text": block["text"]})
+        elif block["type"] in ("reasoning", "tool_call"):
+            pass
+        else:
+            new_content.append(block)
+    new_message = copy.copy(message)
+    new_message.content = new_content
 
-    return message
+    return new_message
 
 
 # v1 / Responses
@ -319,17 +318,18 @@ def _convert_annotation_to_v1(annotation: dict[str, Any]) -> dict[str, Any]:
|
|||||||
for field in ("end_index", "start_index", "title"):
|
for field in ("end_index", "start_index", "title"):
|
||||||
if field in annotation:
|
if field in annotation:
|
||||||
url_citation[field] = annotation[field]
|
url_citation[field] = annotation[field]
|
||||||
url_citation["type"] = "url_citation"
|
url_citation["type"] = "citation"
|
||||||
url_citation["url"] = annotation["url"]
|
url_citation["url"] = annotation["url"]
|
||||||
return url_citation
|
return url_citation
|
||||||
|
|
||||||
elif annotation_type == "file_citation":
|
elif annotation_type == "file_citation":
|
||||||
document_citation = {"type": "document_citation"}
|
document_citation = {"type": "citation"}
|
||||||
if "filename" in annotation:
|
if "filename" in annotation:
|
||||||
document_citation["title"] = annotation["filename"]
|
document_citation["title"] = annotation["filename"]
|
||||||
for field in ("file_id", "index"): # OpenAI-specific
|
if "file_id" in annotation:
|
||||||
if field in annotation:
|
document_citation["file_id"] = annotation["file_id"]
|
||||||
document_citation[field] = annotation[field]
|
if "index" in annotation:
|
||||||
|
document_citation["file_index"] = annotation["index"]
|
||||||
return document_citation
|
return document_citation
|
||||||
|
|
||||||
# TODO: standardise container_file_citation?
|
# TODO: standardise container_file_citation?
|
||||||
@@ -367,13 +367,15 @@ def _explode_reasoning(block: dict[str, Any]) -> Iterable[dict[str, Any]]:
         yield new_block


-def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage:
+def _convert_to_v1_from_responses(
+    content: list[dict[str, Any]],
+    tool_calls: Optional[list[types.ToolCall]] = None,
+    invalid_tool_calls: Optional[list[types.InvalidToolCall]] = None,
+) -> list[types.ContentBlock]:
     """Mutate a Responses message to v1 format."""
-    if not isinstance(message.content, list):
-        return message

     def _iter_blocks() -> Iterable[dict[str, Any]]:
-        for block in message.content:
+        for block in content:
             if not isinstance(block, dict):
                 continue
             block_type = block.get("type")
@@ -409,13 +411,24 @@ def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage:
                 yield new_block

             elif block_type == "function_call":
-                new_block = {"type": "tool_call", "id": block.get("call_id", "")}
-                if "id" in block:
-                    new_block["item_id"] = block["id"]
-                for extra_key in ("arguments", "name", "index"):
-                    if extra_key in block:
-                        new_block[extra_key] = block[extra_key]
-                yield new_block
+                new_block = None
+                call_id = block.get("call_id", "")
+                if call_id:
+                    for tool_call in tool_calls or []:
+                        if tool_call.get("id") == call_id:
+                            new_block = tool_call.copy()
+                            break
+                    else:
+                        for invalid_tool_call in invalid_tool_calls or []:
+                            if invalid_tool_call.get("id") == call_id:
+                                new_block = invalid_tool_call.copy()
+                                break
+                if new_block:
+                    if "id" in block:
+                        new_block["item_id"] = block["id"]
+                    if "index" in block:
+                        new_block["index"] = block["index"]
+                    yield new_block

             elif block_type == "web_search_call":
                 web_search_call = {"type": "web_search_call", "id": block["id"]}
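A rough standalone sketch of the new function_call handling shown above: a Responses "function_call" item is replaced by the already-parsed tool call (or invalid tool call) with the same call_id, carrying the item id and stream index along as extra keys. The names below are illustrative, not the PR's:

    from typing import Any, Optional

    def convert_function_call(
        block: dict[str, Any], tool_calls: list[dict[str, Any]]
    ) -> Optional[dict[str, Any]]:
        call_id = block.get("call_id", "")
        # Reuse the parsed tool call that matches this Responses item, if any.
        new_block = next(
            (tc.copy() for tc in tool_calls if tc.get("id") == call_id), None
        )
        if new_block is not None:
            if "id" in block:
                new_block["item_id"] = block["id"]
            if "index" in block:
                new_block["index"] = block["index"]
        return new_block

    block = {"type": "function_call", "call_id": "call_123", "id": "fc_abc"}
    tool_calls = [{"type": "tool_call", "id": "call_123", "name": "multiply", "args": {"x": 5, "y": 4}}]
    print(convert_function_call(block, tool_calls))
    # {'type': 'tool_call', 'id': 'call_123', 'name': 'multiply', 'args': {'x': 5, 'y': 4}, 'item_id': 'fc_abc'}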
@@ -485,28 +498,26 @@ def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage:
                     new_block["index"] = new_block["value"].pop("index")
                 yield new_block

-    # Replace the list with the fully converted one
-    message.content = list(_iter_blocks())
-
-    return message
+    return list(_iter_blocks())


-def _convert_annotation_from_v1(annotation: dict[str, Any]) -> dict[str, Any]:
-    annotation_type = annotation.get("type")
+def _convert_annotation_from_v1(annotation: types.Annotation) -> dict[str, Any]:
+    if annotation["type"] == "citation":
+        if "url" in annotation:
+            return {**annotation, "type": "url_citation"}

-    if annotation_type == "document_citation":
         new_ann: dict[str, Any] = {"type": "file_citation"}

         if "title" in annotation:
             new_ann["filename"] = annotation["title"]
-        for fld in ("file_id", "index"):
-            if fld in annotation:
-                new_ann[fld] = annotation[fld]
+        if "file_id" in annotation:
+            new_ann["file_id"] = annotation["file_id"]
+        if "file_index" in annotation:
+            new_ann["index"] = annotation["file_index"]

         return new_ann

-    elif annotation_type == "non_standard_annotation":
+    elif annotation["type"] == "non_standard_annotation":
         return annotation["value"]

     else:
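And the reverse direction, sketched from the hunk above: a v1 "citation" maps back to a Responses "url_citation" when it carries a url, otherwise to a "file_citation", while "non_standard_annotation" blocks are unwrapped. This is a hedged illustration, not the PR's code:

    from typing import Any

    def annotation_from_v1(annotation: dict[str, Any]) -> dict[str, Any]:
        if annotation["type"] == "citation":
            if "url" in annotation:
                return {**annotation, "type": "url_citation"}
            out: dict[str, Any] = {"type": "file_citation"}
            if "title" in annotation:
                out["filename"] = annotation["title"]
            if "file_id" in annotation:
                out["file_id"] = annotation["file_id"]
            if "file_index" in annotation:
                out["index"] = annotation["file_index"]
            return out
        if annotation["type"] == "non_standard_annotation":
            return annotation["value"]
        return annotation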
@@ -528,7 +539,10 @@ def _implode_reasoning_blocks(blocks: list[dict[str, Any]]) -> Iterable[dict[str
         elif "reasoning" not in block and "summary" not in block:
             # {"type": "reasoning", "id": "rs_..."}
             oai_format = {**block, "summary": []}
+            # Update key order
             oai_format["type"] = oai_format.pop("type", "reasoning")
+            if "encrypted_content" in oai_format:
+                oai_format["encrypted_content"] = oai_format.pop("encrypted_content")
             yield oai_format
             i += 1
             continue
@@ -594,13 +608,11 @@ def _consolidate_calls(
         # If this really is the matching “result” – collapse
         if nxt.get("type") == result_name and nxt.get("id") == current.get("id"):
             if call_name == "web_search_call":
-                collapsed = {
-                    "id": current["id"],
-                    "status": current["status"],
-                    "type": "web_search_call",
-                }
+                collapsed = {"id": current["id"]}
                 if "action" in current:
                     collapsed["action"] = current["action"]
+                collapsed["status"] = current["status"]
+                collapsed["type"] = "web_search_call"

             if call_name == "code_interpreter_call":
                 collapsed = {"id": current["id"]}
@@ -621,51 +633,50 @@ def _consolidate_calls(
             yield nxt


-def _convert_from_v1_to_responses(message: AIMessage) -> AIMessage:
-    if not isinstance(message.content, list):
-        return message
+def _convert_from_v1_to_responses(
+    content: list[types.ContentBlock], tool_calls: list[types.ToolCall]
+) -> list[dict[str, Any]]:

     new_content: list = []
-    for block in message.content:
-        if isinstance(block, dict):
-            block_type = block.get("type")
-            if block_type == "text" and "annotations" in block:
-                # Need a copy because we’re changing the annotations list
-                new_block = dict(block)
-                new_block["annotations"] = [
-                    _convert_annotation_from_v1(a) for a in block["annotations"]
-                ]
-                new_content.append(new_block)
-            elif block_type == "tool_call":
-                new_block = {"type": "function_call", "call_id": block["id"]}
-                if "item_id" in block:
-                    new_block["id"] = block["item_id"]
-                if "name" in block and "arguments" in block:
-                    new_block["name"] = block["name"]
-                    new_block["arguments"] = block["arguments"]
-                else:
-                    tool_call = next(
-                        call for call in message.tool_calls if call["id"] == block["id"]
-                    )
-                    if "name" not in block:
-                        new_block["name"] = tool_call["name"]
-                    if "arguments" not in block:
-                        new_block["arguments"] = json.dumps(tool_call["args"])
-                new_content.append(new_block)
-            elif (
-                is_data_content_block(block)
-                and block["type"] == "image"
-                and "base64" in block
-            ):
-                new_block = {"type": "image_generation_call", "result": block["base64"]}
-                for extra_key in ("id", "status"):
-                    if extra_key in block:
-                        new_block[extra_key] = block[extra_key]
-                new_content.append(new_block)
-            elif block_type == "non_standard" and "value" in block:
-                new_content.append(block["value"])
-            else:
-                new_content.append(block)
-        else:
-            new_content.append(block)
+    for block in content:
+        if block["type"] == "text" and "annotations" in block:
+            # Need a copy because we’re changing the annotations list
+            new_block = dict(block)
+            new_block["annotations"] = [
+                _convert_annotation_from_v1(a) for a in block["annotations"]
+            ]
+            new_content.append(new_block)
+        elif block["type"] == "tool_call":
+            new_block = {"type": "function_call", "call_id": block["id"]}
+            if "item_id" in block:
+                new_block["id"] = block["item_id"]  # type: ignore[typeddict-item]
+            if "name" in block and "arguments" in block:
+                new_block["name"] = block["name"]
+                new_block["arguments"] = block["arguments"]  # type: ignore[typeddict-item]
+            else:
+                matching_tool_calls = [
+                    call for call in tool_calls if call["id"] == block["id"]
+                ]
+                if matching_tool_calls:
+                    tool_call = matching_tool_calls[0]
+                    if "name" not in block:
+                        new_block["name"] = tool_call["name"]
+                    if "arguments" not in block:
+                        new_block["arguments"] = json.dumps(tool_call["args"])
+            new_content.append(new_block)
+        elif (
+            is_data_content_block(cast(dict, block))
+            and block["type"] == "image"
+            and "base64" in block
+            and isinstance(block.get("id"), str)
+            and block["id"].startswith("ig_")
+        ):
+            new_block = {"type": "image_generation_call", "result": block["base64"]}
+            for extra_key in ("id", "status"):
+                if extra_key in block:
+                    new_block[extra_key] = block[extra_key]  # type: ignore[typeddict-item]
+            new_content.append(new_block)
+        elif block["type"] == "non_standard" and "value" in block:
+            new_content.append(block["value"])
+        else:
+            new_content.append(block)

@@ -679,4 +690,4 @@ def _convert_from_v1_to_responses(message: AIMessage) -> AIMessage:
             )
         )

-    return message.model_copy(update={"content": new_content})
+    return new_content
libs/partners/openai/langchain_openai/chat_models/base_v1.py (new file, 3813 added lines)
File diff suppressed because it is too large
@@ -56,6 +56,8 @@ langchain-tests = { path = "../../standard-tests", editable = true }

 [tool.mypy]
 disallow_untyped_defs = "True"
+disable_error_code = ["typeddict-unknown-key"]

 [[tool.mypy.overrides]]
 module = "transformers"
 ignore_missing_imports = true
Binary file not shown.
@@ -14,16 +14,24 @@ from langchain_core.messages import (
     HumanMessage,
     MessageLikeRepresentation,
 )
+from langchain_core.messages.v1 import AIMessage as AIMessageV1
+from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
+from langchain_core.messages.v1 import HumanMessage as HumanMessageV1
 from pydantic import BaseModel
 from typing_extensions import TypedDict

-from langchain_openai import ChatOpenAI
+from langchain_openai import ChatOpenAI, ChatOpenAIV1

 MODEL_NAME = "gpt-4o-mini"


-def _check_response(response: Optional[BaseMessage]) -> None:
-    assert isinstance(response, AIMessage)
+def _check_response(response: Optional[BaseMessage], output_version) -> None:
+    if output_version == "v1":
+        assert isinstance(response, AIMessageV1) or isinstance(
+            response, AIMessageChunkV1
+        )
+    else:
+        assert isinstance(response, AIMessage)
     assert isinstance(response.content, list)
     for block in response.content:
         assert isinstance(block, dict)
@@ -41,7 +49,10 @@ def _check_response(response: Optional[BaseMessage]) -> None:
             for key in ["end_index", "start_index", "title", "type", "url"]
         )

-    text_content = response.text()
+    if output_version == "v1":
+        text_content = response.text
+    else:
+        text_content = response.text()
     assert isinstance(text_content, str)
     assert text_content
     assert response.usage_metadata
@@ -56,22 +67,34 @@ def _check_response(response: Optional[BaseMessage]) -> None:
 @pytest.mark.vcr
 @pytest.mark.parametrize("output_version", ["responses/v1", "v1"])
 def test_web_search(output_version: Literal["responses/v1", "v1"]) -> None:
-    llm = ChatOpenAI(model=MODEL_NAME, output_version=output_version)
+    if output_version == "v1":
+        llm = ChatOpenAIV1(model=MODEL_NAME)
+    else:
+        llm = ChatOpenAI(model=MODEL_NAME, output_version=output_version)
     first_response = llm.invoke(
         "What was a positive news story from today?",
         tools=[{"type": "web_search_preview"}],
     )
-    _check_response(first_response)
+    _check_response(first_response, output_version)

     # Test streaming
-    full: Optional[BaseMessageChunk] = None
-    for chunk in llm.stream(
-        "What was a positive news story from today?",
-        tools=[{"type": "web_search_preview"}],
-    ):
-        assert isinstance(chunk, AIMessageChunk)
-        full = chunk if full is None else full + chunk
-    _check_response(full)
+    if isinstance(llm, ChatOpenAIV1):
+        full: Optional[AIMessageChunkV1] = None
+        for chunk in llm.stream(
+            "What was a positive news story from today?",
+            tools=[{"type": "web_search_preview"}],
+        ):
+            assert isinstance(chunk, AIMessageChunkV1)
+            full = chunk if full is None else full + chunk
+    else:
+        full: Optional[BaseMessageChunk] = None
+        for chunk in llm.stream(
+            "What was a positive news story from today?",
+            tools=[{"type": "web_search_preview"}],
+        ):
+            assert isinstance(chunk, AIMessageChunk)
+            full = chunk if full is None else full + chunk
+    _check_response(full, output_version)

     # Use OpenAI's stateful API
     response = llm.invoke(
@@ -79,38 +102,26 @@ def test_web_search(output_version: Literal["responses/v1", "v1"]) -> None:
         tools=[{"type": "web_search_preview"}],
         previous_response_id=first_response.response_metadata["id"],
     )
-    _check_response(response)
+    _check_response(response, output_version)

     # Manually pass in chat history
     response = llm.invoke(
         [
-            {
-                "role": "user",
-                "content": [
-                    {
-                        "type": "text",
-                        "text": "What was a positive news story from today?",
-                    }
-                ],
-            },
+            {"role": "user", "content": "What was a positive news story from today?"},
             first_response,
-            {
-                "role": "user",
-                "content": [{"type": "text", "text": "what about a negative one"}],
-            },
+            {"role": "user", "content": "what about a negative one"},
         ],
         tools=[{"type": "web_search_preview"}],
     )
-    _check_response(response)
+    _check_response(response, output_version)

     # Bind tool
     response = llm.bind_tools([{"type": "web_search_preview"}]).invoke(
         "What was a positive news story from today?"
     )
-    _check_response(response)
+    _check_response(response, output_version)

     for msg in [first_response, full, response]:
-        assert isinstance(msg, AIMessage)
         block_types = [block["type"] for block in msg.content]  # type: ignore[index]
         if output_version == "responses/v1":
             assert block_types == ["web_search_call", "text"]
@@ -125,7 +136,7 @@ async def test_web_search_async() -> None:
         "What was a positive news story from today?",
         tools=[{"type": "web_search_preview"}],
     )
-    _check_response(response)
+    _check_response(response, "v0")
     assert response.response_metadata["status"]

     # Test streaming
@@ -137,7 +148,7 @@ async def test_web_search_async() -> None:
         assert isinstance(chunk, AIMessageChunk)
         full = chunk if full is None else full + chunk
     assert isinstance(full, AIMessageChunk)
-    _check_response(full)
+    _check_response(full, "v0")

     for msg in [response, full]:
         assert msg.additional_kwargs["tool_outputs"]
@@ -148,8 +159,8 @@ async def test_web_search_async() -> None:

 @pytest.mark.default_cassette("test_function_calling.yaml.gz")
 @pytest.mark.vcr
-@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
-def test_function_calling(output_version: Literal["v0", "responses/v1", "v1"]) -> None:
+@pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
+def test_function_calling(output_version: Literal["v0", "responses/v1"]) -> None:
     def multiply(x: int, y: int) -> int:
         """return x * y"""
         return x * y
@@ -170,7 +181,33 @@ def test_function_calling(output_version: Literal["v0", "responses/v1", "v1"]) -
     assert set(full.tool_calls[0]["args"]) == {"x", "y"}

     response = bound_llm.invoke("What was a positive news story from today?")
-    _check_response(response)
+    _check_response(response, output_version)
+
+
+@pytest.mark.default_cassette("test_function_calling.yaml.gz")
+@pytest.mark.vcr
+def test_function_calling_v1() -> None:
+    def multiply(x: int, y: int) -> int:
+        """return x * y"""
+        return x * y
+
+    llm = ChatOpenAIV1(model=MODEL_NAME)
+    bound_llm = llm.bind_tools([multiply, {"type": "web_search_preview"}])
+    ai_msg = bound_llm.invoke("whats 5 * 4")
+    assert len(ai_msg.tool_calls) == 1
+    assert ai_msg.tool_calls[0]["name"] == "multiply"
+    assert set(ai_msg.tool_calls[0]["args"]) == {"x", "y"}
+
+    full: Any = None
+    for chunk in bound_llm.stream("whats 5 * 4"):
+        assert isinstance(chunk, AIMessageChunkV1)
+        full = chunk if full is None else full + chunk
+    assert len(full.tool_calls) == 1
+    assert full.tool_calls[0]["name"] == "multiply"
+    assert set(full.tool_calls[0]["args"]) == {"x", "y"}
+
+    response = bound_llm.invoke("What was a positive news story from today?")
+    _check_response(response, "v1")


 class Foo(BaseModel):
@@ -183,10 +220,8 @@ class FooDict(TypedDict):

 @pytest.mark.default_cassette("test_parsed_pydantic_schema.yaml.gz")
 @pytest.mark.vcr
-@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
-def test_parsed_pydantic_schema(
-    output_version: Literal["v0", "responses/v1", "v1"],
-) -> None:
+@pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
+def test_parsed_pydantic_schema(output_version: Literal["v0", "responses/v1"]) -> None:
     llm = ChatOpenAI(
         model=MODEL_NAME, use_responses_api=True, output_version=output_version
     )
@@ -206,6 +241,28 @@ def test_parsed_pydantic_schema(
     assert parsed.response


+@pytest.mark.default_cassette("test_parsed_pydantic_schema.yaml.gz")
+@pytest.mark.vcr
+def test_parsed_pydantic_schema_v1() -> None:
+    llm = ChatOpenAIV1(model=MODEL_NAME, use_responses_api=True)
+    response = llm.invoke("how are ya", response_format=Foo)
+    parsed = Foo(**json.loads(response.text))
+    assert parsed == response.parsed
+    assert parsed.response
+
+    # Test stream
+    full: Optional[AIMessageChunkV1] = None
+    chunks = []
+    for chunk in llm.stream("how are ya", response_format=Foo):
+        assert isinstance(chunk, AIMessageChunkV1)
+        full = chunk if full is None else full + chunk
+        chunks.append(chunk)
+    assert isinstance(full, AIMessageChunkV1)
+    parsed = Foo(**json.loads(full.text))
+    assert parsed == full.parsed
+    assert parsed.response
+
+
 async def test_parsed_pydantic_schema_async() -> None:
     llm = ChatOpenAI(model=MODEL_NAME, use_responses_api=True)
     response = await llm.ainvoke("how are ya", response_format=Foo)
@@ -311,8 +368,8 @@ def test_function_calling_and_structured_output() -> None:

 @pytest.mark.default_cassette("test_reasoning.yaml.gz")
 @pytest.mark.vcr
-@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
-def test_reasoning(output_version: Literal["v0", "responses/v1", "v1"]) -> None:
+@pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
+def test_reasoning(output_version: Literal["v0", "responses/v1"]) -> None:
     llm = ChatOpenAI(
         model="o4-mini", use_responses_api=True, output_version=output_version
     )
@@ -337,6 +394,26 @@ def test_reasoning(output_version: Literal["v0", "responses/v1", "v1"]) -> None:
         assert block_types == ["reasoning", "text"]


+@pytest.mark.default_cassette("test_reasoning.yaml.gz")
+@pytest.mark.vcr
+def test_reasoning_v1() -> None:
+    llm = ChatOpenAIV1(model="o4-mini", use_responses_api=True)
+    response = llm.invoke("Hello", reasoning={"effort": "low"})
+    assert isinstance(response, AIMessageV1)
+
+    # Test init params + streaming
+    llm = ChatOpenAIV1(model="o4-mini", reasoning={"effort": "low"})
+    full: Optional[AIMessageChunkV1] = None
+    for chunk in llm.stream("Hello"):
+        assert isinstance(chunk, AIMessageChunkV1)
+        full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunkV1)
+
+    for msg in [response, full]:
+        block_types = [block["type"] for block in msg.content]
+        assert block_types == ["reasoning", "text"]
+
+
 def test_stateful_api() -> None:
     llm = ChatOpenAI(model=MODEL_NAME, use_responses_api=True)
     response = llm.invoke("how are you, my name is Bobo")
@@ -380,14 +457,14 @@ def test_file_search() -> None:

     input_message = {"role": "user", "content": "What is deep research by OpenAI?"}
     response = llm.invoke([input_message], tools=[tool])
-    _check_response(response)
+    _check_response(response, "v0")

     full: Optional[BaseMessageChunk] = None
     for chunk in llm.stream([input_message], tools=[tool]):
         assert isinstance(chunk, AIMessageChunk)
         full = chunk if full is None else full + chunk
     assert isinstance(full, AIMessageChunk)
-    _check_response(full)
+    _check_response(full, "v0")

     next_message = {"role": "user", "content": "Thank you."}
     _ = llm.invoke([input_message, full, next_message])
@@ -395,9 +472,9 @@ def test_file_search() -> None:

 @pytest.mark.default_cassette("test_stream_reasoning_summary.yaml.gz")
 @pytest.mark.vcr
-@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
+@pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
 def test_stream_reasoning_summary(
-    output_version: Literal["v0", "responses/v1", "v1"],
+    output_version: Literal["v0", "responses/v1"],
 ) -> None:
     llm = ChatOpenAI(
         model="o4-mini",
@@ -424,7 +501,8 @@ def test_stream_reasoning_summary(
             assert isinstance(block["type"], str)
             assert isinstance(block["text"], str)
             assert block["text"]
-    elif output_version == "responses/v1":
+    else:
+        # output_version == "responses/v1"
         reasoning = next(
             block
             for block in response_1.content
@@ -438,18 +516,6 @@ def test_stream_reasoning_summary(
             assert isinstance(block["type"], str)
             assert isinstance(block["text"], str)
             assert block["text"]
-    else:
-        # v1
-        total_reasoning_blocks = 0
-        for block in response_1.content:
-            if block["type"] == "reasoning":
-                total_reasoning_blocks += 1
-                assert isinstance(block["id"], str) and block["id"].startswith("rs_")
-                assert isinstance(block["reasoning"], str)
-                assert isinstance(block["index"], int)
-        assert (
-            total_reasoning_blocks > 1
-        )  # This query typically generates multiple reasoning blocks

     # Check we can pass back summaries
     message_2 = {"role": "user", "content": "Thank you."}
@@ -457,10 +523,45 @@ def test_stream_reasoning_summary(
     assert isinstance(response_2, AIMessage)


+@pytest.mark.default_cassette("test_stream_reasoning_summary.yaml.gz")
+@pytest.mark.vcr
+def test_stream_reasoning_summary_v1() -> None:
+    llm = ChatOpenAIV1(
+        model="o4-mini",
+        # Routes to Responses API if `reasoning` is set.
+        reasoning={"effort": "medium", "summary": "auto"},
+    )
+    message_1 = {
+        "role": "user",
+        "content": "What was the third tallest buliding in the year 2000?",
+    }
+    response_1: Optional[AIMessageChunkV1] = None
+    for chunk in llm.stream([message_1]):
+        assert isinstance(chunk, AIMessageChunkV1)
+        response_1 = chunk if response_1 is None else response_1 + chunk
+    assert isinstance(response_1, AIMessageChunkV1)
+
+    total_reasoning_blocks = 0
+    for block in response_1.content:
+        if block["type"] == "reasoning":
+            total_reasoning_blocks += 1
+            assert isinstance(block["id"], str) and block["id"].startswith("rs_")
+            assert isinstance(block["reasoning"], str)
+            assert isinstance(block["index"], int)
+    assert (
+        total_reasoning_blocks > 1
+    )  # This query typically generates multiple reasoning blocks
+
+    # Check we can pass back summaries
+    message_2 = {"role": "user", "content": "Thank you."}
+    response_2 = llm.invoke([message_1, response_1, message_2])
+    assert isinstance(response_2, AIMessageV1)
+
+
 @pytest.mark.default_cassette("test_code_interpreter.yaml.gz")
 @pytest.mark.vcr
-@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
-def test_code_interpreter(output_version: Literal["v0", "responses/v1", "v1"]) -> None:
+@pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
+def test_code_interpreter(output_version: Literal["v0", "responses/v1"]) -> None:
     llm = ChatOpenAI(
         model="o4-mini", use_responses_api=True, output_version=output_version
     )
@@ -473,33 +574,20 @@ def test_code_interpreter(output_version: Literal["v0", "responses/v1", "v1"]) -
     }
     response = llm_with_tools.invoke([input_message])
     assert isinstance(response, AIMessage)
-    _check_response(response)
+    _check_response(response, output_version)
     if output_version == "v0":
         tool_outputs = [
             item
             for item in response.additional_kwargs["tool_outputs"]
             if item["type"] == "code_interpreter_call"
         ]
-    elif output_version == "responses/v1":
-        tool_outputs = [
-            item
-            for item in response.content
-            if isinstance(item, dict) and item["type"] == "code_interpreter_call"
-        ]
     else:
-        # v1
+        # responses/v1
         tool_outputs = [
             item
             for item in response.content
             if isinstance(item, dict) and item["type"] == "code_interpreter_call"
         ]
-        code_interpreter_result = next(
-            item
-            for item in response.content
-            if isinstance(item, dict) and item["type"] == "code_interpreter_result"
-        )
-        assert tool_outputs
-        assert code_interpreter_result
     assert len(tool_outputs) == 1

     # Test streaming
@@ -520,25 +608,65 @@ def test_code_interpreter(output_version: Literal["v0", "responses/v1", "v1"]) -
             for item in response.additional_kwargs["tool_outputs"]
             if item["type"] == "code_interpreter_call"
         ]
-    elif output_version == "responses/v1":
+    else:
+        # responses/v1
         tool_outputs = [
             item
             for item in response.content
             if isinstance(item, dict) and item["type"] == "code_interpreter_call"
         ]
-    else:
-        code_interpreter_call = next(
-            item
-            for item in response.content
-            if isinstance(item, dict) and item["type"] == "code_interpreter_call"
-        )
-        code_interpreter_result = next(
-            item
-            for item in response.content
-            if isinstance(item, dict) and item["type"] == "code_interpreter_result"
-        )
-        assert code_interpreter_call
-        assert code_interpreter_result
+    assert tool_outputs
+
+    # Test we can pass back in
+    next_message = {"role": "user", "content": "Please add more comments to the code."}
+    _ = llm_with_tools.invoke([input_message, full, next_message])
+
+
+@pytest.mark.default_cassette("test_code_interpreter.yaml.gz")
+@pytest.mark.vcr
+def test_code_interpreter_v1() -> None:
+    llm = ChatOpenAIV1(model="o4-mini", use_responses_api=True)
+    llm_with_tools = llm.bind_tools(
+        [{"type": "code_interpreter", "container": {"type": "auto"}}]
+    )
+    input_message = {
+        "role": "user",
+        "content": "Write and run code to answer the question: what is 3^3?",
+    }
+    response = llm_with_tools.invoke([input_message])
+    assert isinstance(response, AIMessageV1)
+    _check_response(response, "v1")
+
+    tool_outputs = [
+        item for item in response.content if item["type"] == "code_interpreter_call"
+    ]
+    code_interpreter_result = next(
+        item for item in response.content if item["type"] == "code_interpreter_result"
+    )
+    assert tool_outputs
+    assert code_interpreter_result
+    assert len(tool_outputs) == 1
+
+    # Test streaming
+    # Use same container
+    container_id = tool_outputs[0]["container_id"]
+    llm_with_tools = llm.bind_tools(
+        [{"type": "code_interpreter", "container": container_id}]
+    )
+
+    full: Optional[AIMessageChunkV1] = None
+    for chunk in llm_with_tools.stream([input_message]):
+        assert isinstance(chunk, AIMessageChunkV1)
+        full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunkV1)
+    code_interpreter_call = next(
+        item for item in full.content if item["type"] == "code_interpreter_call"
+    )
+    code_interpreter_result = next(
+        item for item in full.content if item["type"] == "code_interpreter_result"
+    )
+    assert code_interpreter_call
+    assert code_interpreter_result
     assert tool_outputs

     # Test we can pass back in
@@ -634,9 +762,59 @@ def test_mcp_builtin_zdr() -> None:
     _ = llm_with_tools.invoke([input_message, full, approval_message])


+@pytest.mark.default_cassette("test_mcp_builtin_zdr.yaml.gz")
+@pytest.mark.vcr
+def test_mcp_builtin_zdr_v1() -> None:
+    llm = ChatOpenAIV1(
+        model="o4-mini", store=False, include=["reasoning.encrypted_content"]
+    )
+
+    llm_with_tools = llm.bind_tools(
+        [
+            {
+                "type": "mcp",
+                "server_label": "deepwiki",
+                "server_url": "https://mcp.deepwiki.com/mcp",
+                "require_approval": {"always": {"tool_names": ["read_wiki_structure"]}},
+            }
+        ]
+    )
+    input_message = {
+        "role": "user",
+        "content": (
+            "What transport protocols does the 2025-03-26 version of the MCP spec "
+            "support?"
+        ),
+    }
+    full: Optional[AIMessageChunkV1] = None
+    for chunk in llm_with_tools.stream([input_message]):
+        assert isinstance(chunk, AIMessageChunkV1)
+        full = chunk if full is None else full + chunk
+
+    assert isinstance(full, AIMessageChunkV1)
+    assert all(isinstance(block, dict) for block in full.content)
+
+    approval_message = HumanMessageV1(
+        [
+            {
+                "type": "non_standard",
+                "value": {
+                    "type": "mcp_approval_response",
+                    "approve": True,
+                    "approval_request_id": block["value"]["id"],  # type: ignore[index]
+                },
+            }
+            for block in full.content
+            if block["type"] == "non_standard"
+            and block["value"]["type"] == "mcp_approval_request"  # type: ignore[index]
+        ]
+    )
+    _ = llm_with_tools.invoke([input_message, full, approval_message])
+
+
 @pytest.mark.default_cassette("test_image_generation_streaming.yaml.gz")
 @pytest.mark.vcr
-@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
+@pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
 def test_image_generation_streaming(output_version: str) -> None:
     """Test image generation streaming."""
     llm = ChatOpenAI(
@@ -710,9 +888,52 @@ def test_image_generation_streaming(output_version: str) -> None:
     assert set(standard_keys).issubset(tool_output.keys())


+@pytest.mark.default_cassette("test_image_generation_streaming.yaml.gz")
+@pytest.mark.vcr
+def test_image_generation_streaming_v1() -> None:
+    """Test image generation streaming."""
+    llm = ChatOpenAIV1(model="gpt-4.1", use_responses_api=True)
+    tool = {
+        "type": "image_generation",
+        "quality": "low",
+        "output_format": "jpeg",
+        "output_compression": 100,
+        "size": "1024x1024",
+    }
+
+    expected_keys = {
+        # Standard
+        "type",
+        "base64",
+        "mime_type",
+        "id",
+        "index",
+        # OpenAI-specific
+        "background",
+        "output_format",
+        "quality",
+        "revised_prompt",
+        "size",
+        "status",
+    }
+
+    full: Optional[AIMessageChunkV1] = None
+    for chunk in llm.stream("Draw a random short word in green font.", tools=[tool]):
+        assert isinstance(chunk, AIMessageChunkV1)
+        full = chunk if full is None else full + chunk
+    complete_ai_message = cast(AIMessageChunkV1, full)
+
+    tool_output = next(
+        block
+        for block in complete_ai_message.content
+        if isinstance(block, dict) and block["type"] == "image"
+    )
+    assert set(expected_keys).issubset(tool_output.keys())
+
+
 @pytest.mark.default_cassette("test_image_generation_multi_turn.yaml.gz")
 @pytest.mark.vcr
-@pytest.mark.parametrize("output_version", ["v0", "responses/v1", "v1"])
+@pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
 def test_image_generation_multi_turn(output_version: str) -> None:
     """Test multi-turn editing of image generation by passing in history."""
     # Test multi-turn
@@ -735,7 +956,7 @@ def test_image_generation_multi_turn(output_version: str) -> None:
     ]
     ai_message = llm_with_tools.invoke(chat_history)
     assert isinstance(ai_message, AIMessage)
-    _check_response(ai_message)
+    _check_response(ai_message, output_version)

     expected_keys = {
         "id",
@@ -801,7 +1022,7 @@ def test_image_generation_multi_turn(output_version: str) -> None:

     ai_message2 = llm_with_tools.invoke(chat_history)
     assert isinstance(ai_message2, AIMessage)
-    _check_response(ai_message2)
+    _check_response(ai_message2, output_version)

     if output_version == "v0":
         tool_output = ai_message2.additional_kwargs["tool_outputs"][0]
@@ -821,3 +1042,76 @@ def test_image_generation_multi_turn(output_version: str) -> None:
         if isinstance(block, dict) and block["type"] == "image"
     )
     assert set(standard_keys).issubset(tool_output.keys())
+
+
+@pytest.mark.default_cassette("test_image_generation_multi_turn.yaml.gz")
+@pytest.mark.vcr
+def test_image_generation_multi_turn_v1() -> None:
+    """Test multi-turn editing of image generation by passing in history."""
+    # Test multi-turn
+    llm = ChatOpenAIV1(model="gpt-4.1", use_responses_api=True)
+    # Test invocation
+    tool = {
+        "type": "image_generation",
+        "quality": "low",
+        "output_format": "jpeg",
+        "output_compression": 100,
+        "size": "1024x1024",
+    }
+    llm_with_tools = llm.bind_tools([tool])
+
+    chat_history: list[MessageLikeRepresentation] = [
+        {"role": "user", "content": "Draw a random short word in green font."}
+    ]
+    ai_message = llm_with_tools.invoke(chat_history)
+    assert isinstance(ai_message, AIMessageV1)
+    _check_response(ai_message, "v1")
+
+    expected_keys = {
+        # Standard
+        "type",
+        "base64",
+        "mime_type",
+        "id",
+        # OpenAI-specific
+        "background",
+        "output_format",
+        "quality",
+        "revised_prompt",
+        "size",
+        "status",
+    }
+
+    standard_keys = {"type", "base64", "id", "status"}
+    tool_output = next(
+        block
+        for block in ai_message.content
+        if isinstance(block, dict) and block["type"] == "image"
+    )
+    assert set(standard_keys).issubset(tool_output.keys())
+
+    chat_history.extend(
+        [
+            # AI message with tool output
+            ai_message,
+            # New request
+            {
+                "role": "user",
+                "content": (
+                    "Now, change the font to blue. Keep the word and everything else "
+                    "the same."
+                ),
+            },
+        ]
+    )
+
+    ai_message2 = llm_with_tools.invoke(chat_history)
+    assert isinstance(ai_message2, AIMessageV1)
+    _check_response(ai_message2, "v1")
+
+    tool_output = next(
+        block
+        for block in ai_message2.content
+        if isinstance(block, dict) and block["type"] == "image"
+    )
+    assert set(expected_keys).issubset(tool_output.keys())
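Taken together, the new tests exercise a usage pattern along these lines (a condensed, hedged sketch; the model name and prompt are illustrative, not taken from the PR):

    from langchain_openai import ChatOpenAIV1

    llm = ChatOpenAIV1(model="gpt-4o-mini")
    msg = llm.invoke(
        "What was a positive news story from today?",
        tools=[{"type": "web_search_preview"}],
    )
    for block in msg.content:   # v1 messages expose typed content blocks (dicts)
        print(block["type"])
    print(msg.text)             # .text is a property on v1 messages, not a method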