diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index eaa39f931b4..9dcd4b7207a 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -307,7 +307,7 @@ def _convert_from_v1_message(message: MessageV1) -> BaseMessage: id=message.id, name=message.name, tool_calls=message.tool_calls, - response_metadata=cast(dict, message.response_metadata), + response_metadata=cast("dict", message.response_metadata), ) if isinstance(message, AIMessageChunkV1): return AIMessageChunk( @@ -315,7 +315,7 @@ def _convert_from_v1_message(message: MessageV1) -> BaseMessage: id=message.id, name=message.name, tool_call_chunks=message.tool_call_chunks, - response_metadata=cast(dict, message.response_metadata), + response_metadata=cast("dict", message.response_metadata), ) if isinstance(message, HumanMessageV1): return HumanMessage( diff --git a/libs/core/langchain_core/messages/v1.py b/libs/core/langchain_core/messages/v1.py index b784cbcfe2f..de662cc6177 100644 --- a/libs/core/langchain_core/messages/v1.py +++ b/libs/core/langchain_core/messages/v1.py @@ -5,6 +5,8 @@ import uuid from dataclasses import dataclass, field from typing import Any, Literal, Optional, TypedDict, Union, cast, get_args +from pydantic import BaseModel + import langchain_core.messages.content_blocks as types from langchain_core.messages.ai import _LC_ID_PREFIX, UsageMetadata, add_usage from langchain_core.messages.base import merge_content @@ -32,20 +34,20 @@ def _ensure_id(id_val: Optional[str]) -> str: return id_val or str(uuid.uuid4()) -class Provider(TypedDict): - """Information about the provider that generated the message. +class ResponseMetadata(TypedDict, total=False): + """Metadata about the response from the AI provider. - Contains metadata about the AI provider and model used to generate content. + Contains additional information returned by the provider, such as + response headers, service tiers, log probabilities, system fingerprints, etc. - Attributes: - name: Name and version of the provider that created the content block. - model_name: Name of the model that generated the content block. + Keys other than those typed here are also permitted. """ - name: str - """Name and version of the provider that created the content block.""" + model_provider: str + """Name of the provider that created the message (e.g., openai).""" + model_name: str - """Name of the model that generated the content block.""" + """Name of the model that generated the message.""" @dataclass @@ -91,21 +93,29 @@ class AIMessage: usage_metadata: Optional[UsageMetadata] = None """If provided, usage metadata for a message, such as token counts.""" - response_metadata: dict = field(default_factory=dict) + response_metadata: ResponseMetadata = field( + default_factory=lambda: ResponseMetadata() + ) """Metadata about the response. This field should include non-standard data returned by the provider, such as response headers, service tiers, or log probabilities. 
""" + parsed: Optional[Union[dict[str, Any], BaseModel]] = None + """Auto-parsed message contents, if applicable.""" + def __init__( self, content: Union[str, list[types.ContentBlock]], id: Optional[str] = None, name: Optional[str] = None, lc_version: str = "v1", - response_metadata: Optional[dict] = None, + response_metadata: Optional[ResponseMetadata] = None, usage_metadata: Optional[UsageMetadata] = None, + tool_calls: Optional[list[types.ToolCall]] = None, + invalid_tool_calls: Optional[list[types.InvalidToolCall]] = None, + parsed: Optional[Union[dict[str, Any], BaseModel]] = None, ): """Initialize an AI message. @@ -116,6 +126,11 @@ class AIMessage: lc_version: Encoding version for the message. response_metadata: Optional metadata about the response. usage_metadata: Optional metadata about token usage. + tool_calls: Optional list of tool calls made by the AI. Tool calls should + generally be included in message content. If passed on init, they will + be added to the content list. + invalid_tool_calls: Optional list of tool calls that failed validation. + parsed: Optional auto-parsed message contents, if applicable. """ if isinstance(content, str): self.content = [{"type": "text", "text": content}] @@ -126,13 +141,27 @@ class AIMessage: self.name = name self.lc_version = lc_version self.usage_metadata = usage_metadata + self.parsed = parsed if response_metadata is None: self.response_metadata = {} else: self.response_metadata = response_metadata - self._tool_calls: list[types.ToolCall] = [] - self._invalid_tool_calls: list[types.InvalidToolCall] = [] + # Add tool calls to content if provided on init + if tool_calls: + content_tool_calls = { + block["id"] + for block in self.content + if block["type"] == "tool_call" and "id" in block + } + for tool_call in tool_calls: + if "id" in tool_call and tool_call["id"] in content_tool_calls: + continue + self.content.append(tool_call) + self._tool_calls = [ + block for block in self.content if block["type"] == "tool_call" + ] + self.invalid_tool_calls = invalid_tool_calls or [] @property def text(self) -> Optional[str]: @@ -150,7 +179,7 @@ class AIMessage: tool_calls = [block for block in self.content if block["type"] == "tool_call"] if tool_calls: self._tool_calls = tool_calls - return self._tool_calls + return [block for block in self.content if block["type"] == "tool_call"] @tool_calls.setter def tool_calls(self, value: list[types.ToolCall]) -> None: @@ -202,13 +231,16 @@ class AIMessageChunk: These data represent incremental usage statistics, as opposed to a running total. """ - response_metadata: dict = field(init=False) + response_metadata: ResponseMetadata = field(init=False) """Metadata about the response chunk. This field should include non-standard data returned by the provider, such as response headers, service tiers, or log probabilities. """ + parsed: Optional[Union[dict[str, Any], BaseModel]] = None + """Auto-parsed message contents, if applicable.""" + tool_call_chunks: list[types.ToolCallChunk] = field(init=False) def __init__( @@ -217,9 +249,10 @@ class AIMessageChunk: id: Optional[str] = None, name: Optional[str] = None, lc_version: str = "v1", - response_metadata: Optional[dict] = None, + response_metadata: Optional[ResponseMetadata] = None, usage_metadata: Optional[UsageMetadata] = None, tool_call_chunks: Optional[list[types.ToolCallChunk]] = None, + parsed: Optional[Union[dict[str, Any], BaseModel]] = None, ): """Initialize an AI message. 
@@ -231,6 +264,7 @@ class AIMessageChunk: response_metadata: Optional metadata about the response. usage_metadata: Optional metadata about token usage. tool_call_chunks: Optional list of partial tool call data. + parsed: Optional auto-parsed message contents, if applicable. """ if isinstance(content, str): self.content = [{"type": "text", "text": content, "index": 0}] @@ -241,6 +275,7 @@ class AIMessageChunk: self.name = name self.lc_version = lc_version self.usage_metadata = usage_metadata + self.parsed = parsed if response_metadata is None: self.response_metadata = {} else: @@ -251,7 +286,7 @@ class AIMessageChunk: self.tool_call_chunks = tool_call_chunks self._tool_calls: list[types.ToolCall] = [] - self._invalid_tool_calls: list[types.InvalidToolCall] = [] + self.invalid_tool_calls: list[types.InvalidToolCall] = [] self._init_tool_calls() def _init_tool_calls(self) -> None: @@ -264,7 +299,7 @@ class AIMessageChunk: ValueError: If the tool call chunks are malformed. """ self._tool_calls = [] - self._invalid_tool_calls = [] + self.invalid_tool_calls = [] if not self.tool_call_chunks: if self._tool_calls: self.tool_call_chunks = [ @@ -276,14 +311,14 @@ class AIMessageChunk: ) for tc in self._tool_calls ] - if self._invalid_tool_calls: + if self.invalid_tool_calls: tool_call_chunks = self.tool_call_chunks tool_call_chunks.extend( [ create_tool_call_chunk( name=tc["name"], args=tc["args"], id=tc["id"], index=None ) - for tc in self._invalid_tool_calls + for tc in self.invalid_tool_calls ] ) self.tool_call_chunks = tool_call_chunks @@ -317,7 +352,7 @@ class AIMessageChunk: except Exception: add_chunk_to_invalid_tool_calls(chunk) self._tool_calls = tool_calls - self._invalid_tool_calls = invalid_tool_calls + self.invalid_tool_calls = invalid_tool_calls @property def text(self) -> Optional[str]: @@ -361,6 +396,20 @@ class AIMessageChunk: error_msg = "Can only add AIMessageChunk or sequence of AIMessageChunk." raise NotImplementedError(error_msg) + def to_message(self) -> "AIMessage": + """Convert this AIMessageChunk to an AIMessage.""" + return AIMessage( + content=self.content, + id=self.id, + name=self.name, + lc_version=self.lc_version, + response_metadata=self.response_metadata, + usage_metadata=self.usage_metadata, + tool_calls=self.tool_calls, + invalid_tool_calls=self.invalid_tool_calls, + parsed=self.parsed, + ) + def add_ai_message_chunks( left: AIMessageChunk, *others: AIMessageChunk @@ -371,7 +420,8 @@ def add_ai_message_chunks( *(cast("list[str | dict[Any, Any]]", o.content) for o in others), ) response_metadata = merge_dicts( - left.response_metadata, *(o.response_metadata for o in others) + cast("dict", left.response_metadata), + *(cast("dict", o.response_metadata) for o in others), ) # Merge tool call chunks @@ -398,6 +448,15 @@ def add_ai_message_chunks( else: usage_metadata = None + # Parsed + # 'parsed' always represents an aggregation not an incremental value, so the last + # non-null value is kept. 
+ parsed = None + for m in reversed([left, *others]): + if m.parsed is not None: + parsed = m.parsed + break + chunk_id = None candidates = [left.id] + [o.id for o in others] # first pass: pick the first non-run-* id @@ -415,8 +474,9 @@ return left.__class__( content=cast("list[types.ContentBlock]", content), tool_call_chunks=tool_call_chunks, - response_metadata=response_metadata, + response_metadata=cast("ResponseMetadata", response_metadata), usage_metadata=usage_metadata, + parsed=parsed, id=chunk_id, ) @@ -453,19 +513,25 @@ class HumanMessage: """ def __init__( - self, content: Union[str, list[types.ContentBlock]], id: Optional[str] = None + self, + content: Union[str, list[types.ContentBlock]], + *, + id: Optional[str] = None, + name: Optional[str] = None, ): """Initialize a human message. Args: content: Message content as string or list of content blocks. id: Optional unique identifier for the message. + name: Optional human-readable name for the message. """ self.id = _ensure_id(id) if isinstance(content, str): self.content = [{"type": "text", "text": content}] else: self.content = content + self.name = name def text(self) -> str: """Extract all text content from the message. @@ -495,20 +561,47 @@ class SystemMessage: content: list[types.ContentBlock] type: Literal["system"] = "system" + name: Optional[str] = None + """An optional name for the message. + + This can be used to provide a human-readable name for the message. + + Usage of this field is optional, and whether it's used or not is up to the + model implementation. + """ + + custom_role: Optional[str] = None + """If provided, a custom role for the system message. + + Example: ``"developer"``. + + Integration packages may use this field to assign the system message role if it + contains a recognized value. + """ + def __init__( - self, content: Union[str, list[types.ContentBlock]], *, id: Optional[str] = None + self, + content: Union[str, list[types.ContentBlock]], + *, + id: Optional[str] = None, + custom_role: Optional[str] = None, + name: Optional[str] = None, ): """Initialize a system message. Args: - content: System instructions as string or list of content blocks. + content: Message content as string or list of content blocks. id: Optional unique identifier for the message. + custom_role: If provided, a custom role for the system message. + name: Optional human-readable name for the message. """ self.id = _ensure_id(id) if isinstance(content, str): self.content = [{"type": "text", "text": content}] else: self.content = content + self.custom_role = custom_role + self.name = name def text(self) -> str: """Extract all text content from the system message.""" @@ -535,11 +628,51 @@ class ToolMessage: id: str tool_call_id: str - content: list[dict[str, Any]] + content: list[types.ContentBlock] artifact: Optional[Any] = None # App-side payload not for the model + + name: Optional[str] = None + """An optional name for the message. + + This can be used to provide a human-readable name for the message. + + Usage of this field is optional, and whether it's used or not is up to the + model implementation. + """ + status: Literal["success", "error"] = "success" type: Literal["tool"] = "tool" + def __init__( + self, + content: Union[str, list[types.ContentBlock]], + tool_call_id: str, + *, + id: Optional[str] = None, + name: Optional[str] = None, + artifact: Optional[Any] = None, + status: Literal["success", "error"] = "success", + ): + """Initialize a tool message. 
+ + Args: + content: Message content as string or list of content blocks. + tool_call_id: ID of the tool call this message responds to. + id: Optional unique identifier for the message. + name: Optional human-readable name for the message. + artifact: Optional app-side payload not intended for the model. + status: Execution status ("success" or "error"). + """ + self.id = _ensure_id(id) + self.tool_call_id = tool_call_id + if isinstance(content, str): + self.content = [{"type": "text", "text": content}] + else: + self.content = content + self.name = name + self.artifact = artifact + self.status = status + @property def text(self) -> str: """Extract all text content from the tool message.""" diff --git a/libs/core/langchain_core/prompt_values.py b/libs/core/langchain_core/prompt_values.py index 681516368f0..68007a7c8ce 100644 --- a/libs/core/langchain_core/prompt_values.py +++ b/libs/core/langchain_core/prompt_values.py @@ -14,16 +14,18 @@ from typing_extensions import TypedDict, overload from langchain_core.load.serializable import Serializable from langchain_core.messages import ( + AIMessage, AnyMessage, BaseMessage, HumanMessage, + SystemMessage, + ToolMessage, get_buffer_string, ) from langchain_core.messages import content_blocks as types -from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage from langchain_core.messages.v1 import AIMessage as AIMessageV1 from langchain_core.messages.v1 import HumanMessage as HumanMessageV1 -from langchain_core.messages.v1 import MessageV1 +from langchain_core.messages.v1 import MessageV1, ResponseMetadata from langchain_core.messages.v1 import SystemMessage as SystemMessageV1 from langchain_core.messages.v1 import ToolMessage as ToolMessageV1 @@ -40,7 +42,7 @@ def _convert_to_v1(message: BaseMessage) -> MessageV1: if isinstance(block, str): content.append({"type": "text", "text": block}) elif isinstance(block, dict): - content.append(block) + content.append(cast("types.ContentBlock", block)) else: pass @@ -52,7 +54,7 @@ def _convert_to_v1(message: BaseMessage) -> MessageV1: return AIMessageV1( content=content, usage_metadata=message.usage_metadata, - response_metadata=message.response_metadata, + response_metadata=cast("ResponseMetadata", message.response_metadata), tool_calls=message.tool_calls, ) if isinstance(message, SystemMessage): @@ -92,8 +94,18 @@ class PromptValue(Serializable, ABC): def to_string(self) -> str: """Return prompt value as string.""" + @overload + def to_messages( + self, output_version: Literal["v0"] = "v0" + ) -> list[BaseMessage]: ... + + @overload + def to_messages(self, output_version: Literal["v1"]) -> list[MessageV1]: ... + @abstractmethod - def to_messages(self) -> list[BaseMessage]: + def to_messages( + self, output_version: Literal["v0", "v1"] = "v0" + ) -> Union[Sequence[BaseMessage], Sequence[MessageV1]]: """Return prompt as a list of Messages.""" @@ -117,10 +129,6 @@ class StringPromptValue(PromptValue): """Return prompt as string.""" return self.text - def to_messages(self) -> list[BaseMessage]: - """Return prompt as messages.""" - return [HumanMessage(content=self.text)] - @overload def to_messages( self, output_version: Literal["v0"] = "v0" @@ -131,12 +139,8 @@ class StringPromptValue(PromptValue): def to_messages( self, output_version: Literal["v0", "v1"] = "v0" - ) -> Union[list[BaseMessage], list[MessageV1]]: - """Return prompt as a list of messages. - - Args: - output_version: The output version, either "v0" (default) or "v1". 
- """ + ) -> Union[Sequence[BaseMessage], Sequence[MessageV1]]: + """Return prompt as messages.""" if output_version == "v1": return [HumanMessageV1(content=self.text)] return [HumanMessage(content=self.text)] @@ -165,7 +169,7 @@ class ChatPromptValue(PromptValue): def to_messages( self, output_version: Literal["v0", "v1"] = "v0" - ) -> Union[list[BaseMessage], list[MessageV1]]: + ) -> Union[Sequence[BaseMessage], Sequence[MessageV1]]: """Return prompt as a list of messages. Args: @@ -207,8 +211,26 @@ class ImagePromptValue(PromptValue): """Return prompt (image URL) as string.""" return self.image_url["url"] - def to_messages(self) -> list[BaseMessage]: + @overload + def to_messages( + self, output_version: Literal["v0"] = "v0" + ) -> list[BaseMessage]: ... + + @overload + def to_messages(self, output_version: Literal["v1"]) -> list[MessageV1]: ... + + def to_messages( + self, output_version: Literal["v0", "v1"] = "v0" + ) -> Union[Sequence[BaseMessage], Sequence[MessageV1]]: """Return prompt (image URL) as messages.""" + if output_version == "v1": + block: types.ImageContentBlock = { + "type": "image", + "url": self.image_url["url"], + } + if "detail" in self.image_url: + block["detail"] = self.image_url["detail"] + return [HumanMessageV1(content=[block])] return [HumanMessage(content=[cast("dict", self.image_url)])] diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml index ededbcc3625..fe324f3eaee 100644 --- a/libs/core/pyproject.toml +++ b/libs/core/pyproject.toml @@ -67,6 +67,7 @@ langchain-text-splitters = { path = "../text-splitters" } strict = "True" strict_bytes = "True" enable_error_code = "deprecated" +disable_error_code = ["typeddict-unknown-key"] # TODO: activate for 'strict' checking disallow_any_generics = "False" diff --git a/libs/partners/openai/langchain_openai/chat_models/_compat.py b/libs/partners/openai/langchain_openai/chat_models/_compat.py index 47d73106069..bcd5b392b67 100644 --- a/libs/partners/openai/langchain_openai/chat_models/_compat.py +++ b/libs/partners/openai/langchain_openai/chat_models/_compat.py @@ -66,11 +66,14 @@ For backwards compatibility, this module provides functions to convert between t formats. The functions are used internally by ChatOpenAI. 
""" # noqa: E501 +import copy import json from collections.abc import Iterable, Iterator from typing import Any, Literal, Union, cast from langchain_core.messages import AIMessage, AIMessageChunk, is_data_content_block +from langchain_core.messages import content_blocks as types +from langchain_core.messages.v1 import AIMessage as AIMessageV1 _FUNCTION_CALL_IDS_MAP_KEY = "__openai_function_call_ids__" @@ -289,25 +292,21 @@ def _convert_to_v1_from_chat_completions_chunk(chunk: AIMessageChunk) -> AIMessa return cast(AIMessageChunk, result) -def _convert_from_v1_to_chat_completions(message: AIMessage) -> AIMessage: +def _convert_from_v1_to_chat_completions(message: AIMessageV1) -> AIMessageV1: """Convert a v1 message to the Chat Completions format.""" - if isinstance(message.content, list): - new_content: list = [] - for block in message.content: - if isinstance(block, dict): - block_type = block.get("type") - if block_type == "text": - # Strip annotations - new_content.append({"type": "text", "text": block["text"]}) - elif block_type in ("reasoning", "tool_call"): - pass - else: - new_content.append(block) - else: - new_content.append(block) - return message.model_copy(update={"content": new_content}) + new_content: list[types.ContentBlock] = [] + for block in message.content: + if block["type"] == "text": + # Strip annotations + new_content.append({"type": "text", "text": block["text"]}) + elif block["type"] in ("reasoning", "tool_call"): + pass + else: + new_content.append(block) + new_message = copy.copy(message) + new_message.content = new_content - return message + return new_message # v1 / Responses @@ -367,13 +366,13 @@ def _explode_reasoning(block: dict[str, Any]) -> Iterable[dict[str, Any]]: yield new_block -def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage: +def _convert_to_v1_from_responses( + content: list[dict[str, Any]], +) -> list[types.ContentBlock]: """Mutate a Responses message to v1 format.""" - if not isinstance(message.content, list): - return message def _iter_blocks() -> Iterable[dict[str, Any]]: - for block in message.content: + for block in content: if not isinstance(block, dict): continue block_type = block.get("type") @@ -485,16 +484,14 @@ def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage: new_block["index"] = new_block["value"].pop("index") yield new_block - # Replace the list with the fully converted one - message.content = list(_iter_blocks()) - - return message + return list(_iter_blocks()) -def _convert_annotation_from_v1(annotation: dict[str, Any]) -> dict[str, Any]: - annotation_type = annotation.get("type") +def _convert_annotation_from_v1(annotation: types.Annotation) -> dict[str, Any]: + if annotation["type"] == "citation": + if "url" in annotation: + return dict(annotation) - if annotation_type == "document_citation": new_ann: dict[str, Any] = {"type": "file_citation"} if "title" in annotation: @@ -502,11 +499,11 @@ def _convert_annotation_from_v1(annotation: dict[str, Any]) -> dict[str, Any]: for fld in ("file_id", "index"): if fld in annotation: - new_ann[fld] = annotation[fld] + new_ann[fld] = annotation[fld] # type: ignore[typeddict-item] return new_ann - elif annotation_type == "non_standard_annotation": + elif annotation["type"] == "non_standard_annotation": return annotation["value"] else: @@ -621,51 +618,48 @@ def _consolidate_calls( yield nxt -def _convert_from_v1_to_responses(message: AIMessage) -> AIMessage: - if not isinstance(message.content, list): - return message - +def _convert_from_v1_to_responses( + 
content: list[types.ContentBlock], tool_calls: list[types.ToolCall] +) -> list[dict[str, Any]]: new_content: list = [] - for block in message.content: - if isinstance(block, dict): - block_type = block.get("type") - if block_type == "text" and "annotations" in block: - # Need a copy because we’re changing the annotations list - new_block = dict(block) - new_block["annotations"] = [ - _convert_annotation_from_v1(a) for a in block["annotations"] + for block in content: + if block["type"] == "text" and "annotations" in block: + # Need a copy because we’re changing the annotations list + new_block = dict(block) + new_block["annotations"] = [ + _convert_annotation_from_v1(a) for a in block["annotations"] + ] + new_content.append(new_block) + elif block["type"] == "tool_call": + new_block = {"type": "function_call", "call_id": block["id"]} + if "item_id" in block: + new_block["id"] = block["item_id"] # type: ignore[typeddict-item] + if "name" in block and "arguments" in block: + new_block["name"] = block["name"] + new_block["arguments"] = block["arguments"] # type: ignore[typeddict-item] + else: + matching_tool_calls = [ + call for call in tool_calls if call["id"] == block["id"] ] - new_content.append(new_block) - elif block_type == "tool_call": - new_block = {"type": "function_call", "call_id": block["id"]} - if "item_id" in block: - new_block["id"] = block["item_id"] - if "name" in block and "arguments" in block: - new_block["name"] = block["name"] - new_block["arguments"] = block["arguments"] - else: - tool_call = next( - call for call in message.tool_calls if call["id"] == block["id"] - ) + if matching_tool_calls: + tool_call = matching_tool_calls[0] if "name" not in block: new_block["name"] = tool_call["name"] if "arguments" not in block: new_block["arguments"] = json.dumps(tool_call["args"]) - new_content.append(new_block) - elif ( - is_data_content_block(block) - and block["type"] == "image" - and "base64" in block - ): - new_block = {"type": "image_generation_call", "result": block["base64"]} - for extra_key in ("id", "status"): - if extra_key in block: - new_block[extra_key] = block[extra_key] - new_content.append(new_block) - elif block_type == "non_standard" and "value" in block: - new_content.append(block["value"]) - else: - new_content.append(block) + new_content.append(new_block) + elif ( + is_data_content_block(cast(dict, block)) + and block["type"] == "image" + and "base64" in block + ): + new_block = {"type": "image_generation_call", "result": block["base64"]} + for extra_key in ("id", "status"): + if extra_key in block: + new_block[extra_key] = block[extra_key] # type: ignore[typeddict-item] + new_content.append(new_block) + elif block["type"] == "non_standard" and "value" in block: + new_content.append(block["value"]) else: new_content.append(block) @@ -679,4 +673,4 @@ def _convert_from_v1_to_responses(message: AIMessage) -> AIMessage: ) ) - return message.model_copy(update={"content": new_content}) + return new_content diff --git a/libs/partners/openai/langchain_openai/chat_models/base_v1.py b/libs/partners/openai/langchain_openai/chat_models/base_v1.py index ba26e8a9dd9..a75e2c99772 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base_v1.py +++ b/libs/partners/openai/langchain_openai/chat_models/base_v1.py @@ -32,7 +32,6 @@ from urllib.parse import urlparse import certifi import openai import tiktoken -from langchain_core._api.deprecation import deprecated from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, @@ -50,21 
+49,18 @@ from langchain_core.messages import ( convert_to_openai_data_block, is_data_content_block, ) -from langchain_core.messages.v1 import ( - AIMessage as AIMessageV1, - AIMessageChunk as AIMessageChunkV1, - HumanMessage as HumanMessageV1, - MessageV1, - SystemMessage as SystemMessageV1, - ToolMessage as ToolMessageV1, -) from langchain_core.messages.ai import ( InputTokenDetails, OutputTokenDetails, UsageMetadata, ) -from langchain_core.messages import content_blocks as types from langchain_core.messages.tool import tool_call_chunk +from langchain_core.messages.v1 import AIMessage as AIMessageV1 +from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1 +from langchain_core.messages.v1 import HumanMessage as HumanMessageV1 +from langchain_core.messages.v1 import MessageV1, ResponseMetadata +from langchain_core.messages.v1 import SystemMessage as SystemMessageV1 +from langchain_core.messages.v1 import ToolMessage as ToolMessageV1 from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser from langchain_core.output_parsers.openai_tools import ( JsonOutputKeyToolsParser, @@ -72,7 +68,6 @@ from langchain_core.output_parsers.openai_tools import ( make_invalid_tool_call, parse_tool_call, ) -from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.runnables import ( Runnable, RunnableLambda, @@ -102,12 +97,8 @@ from langchain_openai.chat_models._client_utils import ( _get_default_httpx_client, ) from langchain_openai.chat_models._compat import ( - _convert_from_v03_ai_message, _convert_from_v1_to_chat_completions, _convert_from_v1_to_responses, - _convert_to_v03_ai_message, - _convert_to_v1_from_chat_completions, - _convert_to_v1_from_chat_completions_chunk, _convert_to_v1_from_responses, ) @@ -158,6 +149,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> MessageV1: invalid_tool_calls.append( make_invalid_tool_call(raw_tool_call, str(e)) ) + content.extend(tool_calls) if audio := _dict.get("audio"): # TODO: populate standard fields content.append({"type": "audio", "audio": audio}) @@ -165,14 +157,15 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> MessageV1: content=content, name=name, id=id_, - tool_calls=tool_calls, - invalid_tool_calls=invalid_tool_calls, + tool_calls=cast(list[ToolCall], tool_calls), + invalid_tool_calls=cast(list[InvalidToolCall], invalid_tool_calls), ) elif role in ("system", "developer"): return SystemMessageV1( content=_dict.get("content", ""), name=name, id=id_, + custom_role=role if role == "developer" else None, ) elif role == "tool": return ToolMessageV1( @@ -193,8 +186,9 @@ def _format_message_content(content: Any) -> Any: for block in content: # Remove unexpected block types if ( - block["type"] == "non_standard" - and block["value"].get("type") in ("tool_use", "thinking", "reasoning_content") + isinstance(block, dict) + and "type" in block + and block["type"] in ("tool_use", "thinking", "reasoning_content") ): continue elif isinstance(block, dict) and is_data_content_block(block): @@ -230,7 +224,7 @@ def _format_message_content(content: Any) -> Any: return formatted_content -def _convert_message_to_dict(message: BaseMessage) -> dict: +def _convert_message_to_dict(message: MessageV1) -> dict: """Convert a LangChain message to a dictionary. Args: @@ -240,15 +234,13 @@ def _convert_message_to_dict(message: BaseMessage) -> dict: The dictionary. 
""" message_dict: dict[str, Any] = {"content": _format_message_content(message.content)} - if (name := message.name or message.additional_kwargs.get("name")) is not None: + if name := message.name: message_dict["name"] = name # populate role and additional message data - if isinstance(message, ChatMessage): - message_dict["role"] = message.role - elif isinstance(message, HumanMessage): + if isinstance(message, HumanMessageV1): message_dict["role"] = "user" - elif isinstance(message, AIMessage): + elif isinstance(message, AIMessageV1): message_dict["role"] = "assistant" if message.tool_calls or message.invalid_tool_calls: message_dict["tool_calls"] = [ @@ -257,40 +249,26 @@ def _convert_message_to_dict(message: BaseMessage) -> dict: _lc_invalid_tool_call_to_openai_tool_call(tc) for tc in message.invalid_tool_calls ] - elif "tool_calls" in message.additional_kwargs: - message_dict["tool_calls"] = message.additional_kwargs["tool_calls"] - tool_call_supported_props = {"id", "type", "function"} - message_dict["tool_calls"] = [ - {k: v for k, v in tool_call.items() if k in tool_call_supported_props} - for tool_call in message_dict["tool_calls"] - ] - elif "function_call" in message.additional_kwargs: - # OpenAI raises 400 if both function_call and tool_calls are present in the - # same message. - message_dict["function_call"] = message.additional_kwargs["function_call"] else: pass # If tool calls present, content null value should be None not empty string. - if "function_call" in message_dict or "tool_calls" in message_dict: + if "tool_calls" in message_dict: message_dict["content"] = message_dict["content"] or None - if "audio" in message.additional_kwargs: - # openai doesn't support passing the data back - only the id - # https://platform.openai.com/docs/guides/audio/multi-turn-conversations - raw_audio = message.additional_kwargs["audio"] - audio = ( - {"id": message.additional_kwargs["audio"]["id"]} - if "id" in raw_audio - else raw_audio - ) + audio: Optional[dict[str, Any]] = None + for block in message.content: + if block.get("type") == "audio" and (id_ := block.get("id")): + # openai doesn't support passing the data back - only the id + # https://platform.openai.com/docs/guides/audio/multi-turn-conversations + audio = {"id": id_} + if audio: message_dict["audio"] = audio - elif isinstance(message, SystemMessage): - message_dict["role"] = message.additional_kwargs.get( - "__openai_role__", "system" - ) - elif isinstance(message, FunctionMessage): - message_dict["role"] = "function" - elif isinstance(message, ToolMessage): + elif isinstance(message, SystemMessageV1): + if message.custom_role == "developer": + message_dict["role"] = "developer" + else: + message_dict["role"] = "system" + elif isinstance(message, ToolMessageV1): message_dict["role"] = "tool" message_dict["tool_call_id"] = message.tool_call_id @@ -301,21 +279,11 @@ def _convert_message_to_dict(message: BaseMessage) -> dict: return message_dict -def _convert_delta_to_message_chunk( - _dict: Mapping[str, Any], default_class: type[BaseMessageChunk] -) -> BaseMessageChunk: +def _convert_delta_to_message_chunk(_dict: Mapping[str, Any]) -> AIMessageChunkV1: id_ = _dict.get("id") - role = cast(str, _dict.get("role")) content = cast(str, _dict.get("content") or "") - additional_kwargs: dict = {} - if _dict.get("function_call"): - function_call = dict(_dict["function_call"]) - if "name" in function_call and function_call["name"] is None: - function_call["name"] = "" - additional_kwargs["function_call"] = function_call tool_call_chunks 
= [] if raw_tool_calls := _dict.get("tool_calls"): - additional_kwargs["tool_calls"] = raw_tool_calls try: tool_call_chunks = [ tool_call_chunk( @@ -329,33 +297,7 @@ def _convert_delta_to_message_chunk( except KeyError: pass - if role == "user" or default_class == HumanMessageChunk: - return HumanMessageChunk(content=content, id=id_) - elif role == "assistant" or default_class == AIMessageChunk: - return AIMessageChunk( - content=content, - additional_kwargs=additional_kwargs, - id=id_, - tool_call_chunks=tool_call_chunks, # type: ignore[arg-type] - ) - elif role in ("system", "developer") or default_class == SystemMessageChunk: - if role == "developer": - additional_kwargs = {"__openai_role__": "developer"} - else: - additional_kwargs = {} - return SystemMessageChunk( - content=content, id=id_, additional_kwargs=additional_kwargs - ) - elif role == "function" or default_class == FunctionMessageChunk: - return FunctionMessageChunk(content=content, name=_dict["name"], id=id_) - elif role == "tool" or default_class == ToolMessageChunk: - return ToolMessageChunk( - content=content, tool_call_id=_dict["tool_call_id"], id=id_ - ) - elif role or default_class == ChatMessageChunk: - return ChatMessageChunk(content=content, role=role, id=id_) - else: - return default_class(content=content, id=id_) # type: ignore + return AIMessageChunkV1(content=content, id=id_, tool_call_chunks=tool_call_chunks) def _update_token_usage( @@ -421,12 +363,12 @@ _DictOrPydantic = Union[dict, _BM] class _AllReturnType(TypedDict): - raw: BaseMessage + raw: AIMessageV1 parsed: Optional[_DictOrPydantic] parsing_error: Optional[BaseException] -class BaseChatOpenAI(BaseChatModel): +class BaseChatOpenAIV1(BaseChatModel): client: Any = Field(default=None, exclude=True) #: :meta private: async_client: Any = Field(default=None, exclude=True) #: :meta private: root_client: Any = Field(default=None, exclude=True) #: :meta private: @@ -784,37 +726,9 @@ class BaseChatOpenAI(BaseChatModel): return params - def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict: - overall_token_usage: dict = {} - system_fingerprint = None - for output in llm_outputs: - if output is None: - # Happens in streaming - continue - token_usage = output.get("token_usage") - if token_usage is not None: - for k, v in token_usage.items(): - if v is None: - continue - if k in overall_token_usage: - overall_token_usage[k] = _update_token_usage( - overall_token_usage[k], v - ) - else: - overall_token_usage[k] = v - if system_fingerprint is None: - system_fingerprint = output.get("system_fingerprint") - combined = {"token_usage": overall_token_usage, "model_name": self.model_name} - if system_fingerprint: - combined["system_fingerprint"] = system_fingerprint - return combined - - def _convert_chunk_to_generation_chunk( - self, - chunk: dict, - default_chunk_class: type, - base_generation_info: Optional[dict], - ) -> Optional[ChatGenerationChunk]: + def _convert_chunk_to_message_chunk( + self, chunk: dict, base_generation_info: Optional[dict] + ) -> Optional[AIMessageChunkV1]: if chunk.get("type") == "content.delta": # from beta.chat.completions.stream return None token_usage = chunk.get("usage") @@ -829,23 +743,17 @@ class BaseChatOpenAI(BaseChatModel): ) if len(choices) == 0: # logprobs is implicitly None - generation_chunk = ChatGenerationChunk( - message=default_chunk_class(content="", usage_metadata=usage_metadata), - generation_info=base_generation_info, + return AIMessageChunkV1( + content=[], + usage_metadata=usage_metadata, + 
response_metadata=cast(ResponseMetadata, base_generation_info), ) - if self.output_version == "v1": - generation_chunk.message = _convert_to_v1_from_chat_completions_chunk( - cast(AIMessageChunk, generation_chunk.message) - ) - return generation_chunk choice = choices[0] if choice["delta"] is None: return None - message_chunk = _convert_delta_to_message_chunk( - choice["delta"], default_chunk_class - ) + message_chunk = _convert_delta_to_message_chunk(choice["delta"]) generation_info = {**base_generation_info} if base_generation_info else {} if finish_reason := choice.get("finish_reason"): @@ -861,35 +769,22 @@ class BaseChatOpenAI(BaseChatModel): if logprobs: generation_info["logprobs"] = logprobs - if usage_metadata and isinstance(message_chunk, AIMessageChunk): + if usage_metadata: message_chunk.usage_metadata = usage_metadata - if self.output_version == "v1": - message_chunk = cast(AIMessageChunk, message_chunk) - # Convert to v1 format - if isinstance(message_chunk.content, str): - message_chunk = _convert_to_v1_from_chat_completions_chunk( - message_chunk - ) - if message_chunk.content: - message_chunk.content[0]["index"] = 0 # type: ignore[index] - else: - message_chunk = _convert_to_v1_from_chat_completions_chunk( - message_chunk - ) - - generation_chunk = ChatGenerationChunk( - message=message_chunk, generation_info=generation_info or None - ) - return generation_chunk + message_chunk.response_metadata = { + **message_chunk.response_metadata, + **generation_info, + } + return message_chunk def _stream_responses( self, - messages: list[BaseMessage], + messages: list[MessageV1], stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, - ) -> Iterator[ChatGenerationChunk]: + ) -> Iterator[AIMessageChunkV1]: kwargs["stream"] = True payload = self._get_request_payload(messages, stop=stop, **kwargs) if self.include_response_headers: @@ -908,7 +803,6 @@ class BaseChatOpenAI(BaseChatModel): current_index = -1 current_output_index = -1 current_sub_index = -1 - has_reasoning = False for chunk in response: metadata = headers if is_first_chunk else {} ( @@ -923,7 +817,6 @@ class BaseChatOpenAI(BaseChatModel): current_sub_index, schema=original_schema_obj, metadata=metadata, - has_reasoning=has_reasoning, output_version=self.output_version, ) if generation_chunk: @@ -932,17 +825,15 @@ class BaseChatOpenAI(BaseChatModel): generation_chunk.text, chunk=generation_chunk ) is_first_chunk = False - if "reasoning" in generation_chunk.message.additional_kwargs: - has_reasoning = True yield generation_chunk async def _astream_responses( self, - messages: list[BaseMessage], + messages: list[MessageV1], stop: Optional[list[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, - ) -> AsyncIterator[ChatGenerationChunk]: + ) -> AsyncIterator[AIMessageChunkV1]: kwargs["stream"] = True payload = self._get_request_payload(messages, stop=stop, **kwargs) if self.include_response_headers: @@ -963,7 +854,6 @@ class BaseChatOpenAI(BaseChatModel): current_index = -1 current_output_index = -1 current_sub_index = -1 - has_reasoning = False async for chunk in response: metadata = headers if is_first_chunk else {} ( @@ -978,7 +868,6 @@ class BaseChatOpenAI(BaseChatModel): current_sub_index, schema=original_schema_obj, metadata=metadata, - has_reasoning=has_reasoning, output_version=self.output_version, ) if generation_chunk: @@ -987,8 +876,6 @@ class BaseChatOpenAI(BaseChatModel): generation_chunk.text, chunk=generation_chunk ) 
is_first_chunk = False - if "reasoning" in generation_chunk.message.additional_kwargs: - has_reasoning = True yield generation_chunk def _should_stream_usage( @@ -1012,19 +899,18 @@ class BaseChatOpenAI(BaseChatModel): def _stream( self, - messages: list[BaseMessage], + messages: list[MessageV1], stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, *, stream_usage: Optional[bool] = None, **kwargs: Any, - ) -> Iterator[ChatGenerationChunk]: + ) -> Iterator[AIMessageChunkV1]: kwargs["stream"] = True stream_usage = self._should_stream_usage(stream_usage, **kwargs) if stream_usage: kwargs["stream_options"] = {"include_usage": stream_usage} payload = self._get_request_payload(messages, stop=stop, **kwargs) - default_chunk_class: type[BaseMessageChunk] = AIMessageChunk base_generation_info = {} if "response_format" in payload: @@ -1050,43 +936,34 @@ class BaseChatOpenAI(BaseChatModel): for chunk in response: if not isinstance(chunk, dict): chunk = chunk.model_dump() - generation_chunk = self._convert_chunk_to_generation_chunk( - chunk, - default_chunk_class, - base_generation_info if is_first_chunk else {}, + message_chunk = self._convert_chunk_to_message_chunk( + chunk, base_generation_info if is_first_chunk else {} ) - if generation_chunk is None: + if message_chunk is None: continue - default_chunk_class = generation_chunk.message.__class__ - logprobs = (generation_chunk.generation_info or {}).get("logprobs") + logprobs = message_chunk.response_metadata.get("logprobs") if run_manager: run_manager.on_llm_new_token( - generation_chunk.text, - chunk=generation_chunk, - logprobs=logprobs, + message_chunk.text, chunk=message_chunk, logprobs=logprobs ) is_first_chunk = False - yield generation_chunk + yield message_chunk except openai.BadRequestError as e: _handle_openai_bad_request(e) if hasattr(response, "get_final_completion") and "response_format" in payload: final_completion = response.get_final_completion() - generation_chunk = self._get_generation_chunk_from_completion( - final_completion - ) + message_chunk = self._get_message_chunk_from_completion(final_completion) if run_manager: - run_manager.on_llm_new_token( - generation_chunk.text, chunk=generation_chunk - ) - yield generation_chunk + run_manager.on_llm_new_token(message_chunk.text, chunk=message_chunk) + yield message_chunk def _generate( self, - messages: list[BaseMessage], + messages: list[MessageV1], stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, - ) -> ChatResult: + ) -> AIMessageV1: if self.streaming: stream_iter = self._stream( messages, stop=stop, run_manager=run_manager, **kwargs @@ -1119,10 +996,7 @@ class BaseChatOpenAI(BaseChatModel): else: response = self.root_client.responses.create(**payload) return _construct_lc_result_from_responses_api( - response, - schema=original_schema_obj, - metadata=generation_info, - output_version=self.output_version, + response, schema=original_schema_obj, metadata=generation_info ) elif self.include_response_headers: raw_response = self.client.with_raw_response.create(**payload) @@ -1130,7 +1004,7 @@ class BaseChatOpenAI(BaseChatModel): generation_info = {"headers": dict(raw_response.headers)} else: response = self.client.create(**payload) - return self._create_chat_result(response, generation_info) + return self._create_ai_message(response, generation_info) def _use_responses_api(self, payload: dict) -> bool: if isinstance(self.use_responses_api, bool): @@ -1155,7 +1029,7 @@ class 
BaseChatOpenAI(BaseChatModel): stop: Optional[list[str]] = None, **kwargs: Any, ) -> dict: - messages = self._convert_input(input_).to_messages() + messages = self._convert_input(input_).to_messages(output_version="v1") if stop is not None: kwargs["stop"] = stop @@ -1172,19 +1046,17 @@ class BaseChatOpenAI(BaseChatModel): else: payload["messages"] = [ _convert_message_to_dict(_convert_from_v1_to_chat_completions(m)) - if isinstance(m, AIMessage) + if isinstance(m, AIMessageV1) else _convert_message_to_dict(m) for m in messages ] return payload - def _create_chat_result( + def _create_ai_message( self, response: Union[dict, openai.BaseModel], generation_info: Optional[dict] = None, - ) -> ChatResult: - generations = [] - + ) -> AIMessageV1: response_dict = ( response if isinstance(response, dict) else response.model_dump() ) @@ -1209,8 +1081,8 @@ class BaseChatOpenAI(BaseChatModel): token_usage = response_dict.get("usage") for res in choices: - message = _convert_dict_to_message(res["message"]) - if token_usage and isinstance(message, AIMessage): + message = cast(AIMessageV1, _convert_dict_to_message(res["message"])) + if token_usage: message.usage_metadata = _create_usage_metadata(token_usage) generation_info = generation_info or {} generation_info["finish_reason"] = ( @@ -1220,11 +1092,10 @@ class BaseChatOpenAI(BaseChatModel): ) if "logprobs" in res: generation_info["logprobs"] = res["logprobs"] - gen = ChatGeneration(message=message, generation_info=generation_info) - generations.append(gen) + message.response_metadata = {**message.response_metadata, **generation_info} llm_output = { - "token_usage": token_usage, "model_name": response_dict.get("model", self.model_name), + "model_provider": "openai", "system_fingerprint": response_dict.get("system_fingerprint", ""), } if "id" in response_dict: @@ -1235,34 +1106,31 @@ class BaseChatOpenAI(BaseChatModel): if isinstance(response, openai.BaseModel) and getattr( response, "choices", None ): - message = response.choices[0].message # type: ignore[attr-defined] - if hasattr(message, "parsed"): - generations[0].message.additional_kwargs["parsed"] = message.parsed - if hasattr(message, "refusal"): - generations[0].message.additional_kwargs["refusal"] = message.refusal + oai_message = response.choices[0].message # type: ignore[attr-defined] + if hasattr(oai_message, "parsed"): + message.parsed = oai_message.parsed + if refusal := getattr(oai_message, "refusal", None): + message.content.append( + {"type": "non_standard", "value": {"refusal": refusal}} + ) - if self.output_version == "v1": - _ = llm_output.pop("token_usage", None) - generations[0].message = _convert_to_v1_from_chat_completions( - cast(AIMessage, generations[0].message) - ) - return ChatResult(generations=generations, llm_output=llm_output) + message.response_metadata = {**message.response_metadata, **llm_output} # type: ignore[typeddict-item] + return message async def _astream( self, - messages: list[BaseMessage], + messages: list[MessageV1], stop: Optional[list[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, *, stream_usage: Optional[bool] = None, **kwargs: Any, - ) -> AsyncIterator[ChatGenerationChunk]: + ) -> AsyncIterator[AIMessageChunkV1]: kwargs["stream"] = True stream_usage = self._should_stream_usage(stream_usage, **kwargs) if stream_usage: kwargs["stream_options"] = {"include_usage": stream_usage} payload = self._get_request_payload(messages, stop=stop, **kwargs) - default_chunk_class: type[BaseMessageChunk] = AIMessageChunk base_generation_info 
= {} if "response_format" in payload: @@ -1292,43 +1160,36 @@ class BaseChatOpenAI(BaseChatModel): async for chunk in response: if not isinstance(chunk, dict): chunk = chunk.model_dump() - generation_chunk = self._convert_chunk_to_generation_chunk( - chunk, - default_chunk_class, - base_generation_info if is_first_chunk else {}, + message_chunk = self._convert_chunk_to_message_chunk( + chunk, base_generation_info if is_first_chunk else {} ) - if generation_chunk is None: + if message_chunk is None: continue - default_chunk_class = generation_chunk.message.__class__ - logprobs = (generation_chunk.generation_info or {}).get("logprobs") + logprobs = message_chunk.response_metadata.get("logprobs") if run_manager: await run_manager.on_llm_new_token( - generation_chunk.text, - chunk=generation_chunk, - logprobs=logprobs, + message_chunk.text, chunk=message_chunk, logprobs=logprobs ) is_first_chunk = False - yield generation_chunk + yield message_chunk except openai.BadRequestError as e: _handle_openai_bad_request(e) if hasattr(response, "get_final_completion") and "response_format" in payload: final_completion = await response.get_final_completion() - generation_chunk = self._get_generation_chunk_from_completion( - final_completion - ) + message_chunk = self._get_message_chunk_from_completion(final_completion) if run_manager: await run_manager.on_llm_new_token( - generation_chunk.text, chunk=generation_chunk + message_chunk.text, chunk=message_chunk ) - yield generation_chunk + yield message_chunk async def _agenerate( self, - messages: list[BaseMessage], + messages: list[MessageV1], stop: Optional[list[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, - ) -> ChatResult: + ) -> AIMessageV1: if self.streaming: stream_iter = self._astream( messages, stop=stop, run_manager=run_manager, **kwargs @@ -1365,10 +1226,7 @@ class BaseChatOpenAI(BaseChatModel): else: response = await self.root_async_client.responses.create(**payload) return _construct_lc_result_from_responses_api( - response, - schema=original_schema_obj, - metadata=generation_info, - output_version=self.output_version, + response, schema=original_schema_obj, metadata=generation_info ) elif self.include_response_headers: raw_response = await self.async_client.with_raw_response.create(**payload) @@ -1377,7 +1235,7 @@ class BaseChatOpenAI(BaseChatModel): else: response = await self.async_client.create(**payload) return await run_in_executor( - None, self._create_chat_result, response, generation_info + None, self._create_ai_message, response, generation_info ) @property @@ -1458,7 +1316,7 @@ class BaseChatOpenAI(BaseChatModel): def get_num_tokens_from_messages( self, - messages: list[BaseMessage], + messages: list[MessageV1], tools: Optional[ Sequence[Union[dict[str, Any], type, Callable, BaseTool]] ] = None, @@ -1483,8 +1341,6 @@ class BaseChatOpenAI(BaseChatModel): warnings.warn( "Counting tokens in tool schemas is not yet supported. Ignoring tools." 
) - if sys.version_info[1] <= 7: - return super().get_num_tokens_from_messages(messages) model, encoding = self._get_encoding_model() if model.startswith("gpt-3.5-turbo-0301"): # every message follows {role/name}\n{content}\n @@ -1554,64 +1410,6 @@ class BaseChatOpenAI(BaseChatModel): num_tokens += 3 return num_tokens - @deprecated( - since="0.2.1", - alternative="langchain_openai.chat_models.base.ChatOpenAI.bind_tools", - removal="1.0.0", - ) - def bind_functions( - self, - functions: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]], - function_call: Optional[ - Union[_FunctionCall, str, Literal["auto", "none"]] - ] = None, - **kwargs: Any, - ) -> Runnable[LanguageModelInput, BaseMessage]: - """Bind functions (and other objects) to this chat model. - - Assumes model is compatible with OpenAI function-calling API. - - NOTE: Using bind_tools is recommended instead, as the `functions` and - `function_call` request parameters are officially marked as deprecated by - OpenAI. - - Args: - functions: A list of function definitions to bind to this chat model. - Can be a dictionary, pydantic model, or callable. Pydantic - models and callables will be automatically converted to - their schema dictionary representation. - function_call: Which function to require the model to call. - Must be the name of the single provided function or - "auto" to automatically determine which function to call - (if any). - **kwargs: Any additional parameters to pass to the - :class:`~langchain.runnable.Runnable` constructor. - """ - - formatted_functions = [convert_to_openai_function(fn) for fn in functions] - if function_call is not None: - function_call = ( - {"name": function_call} - if isinstance(function_call, str) - and function_call not in ("auto", "none") - else function_call - ) - if isinstance(function_call, dict) and len(formatted_functions) != 1: - raise ValueError( - "When specifying `function_call`, you must provide exactly one " - "function." - ) - if ( - isinstance(function_call, dict) - and formatted_functions[0]["name"] != function_call["name"] - ): - raise ValueError( - f"Function call {function_call} was specified, but the only " - f"provided function was {formatted_functions[0]['name']}." - ) - kwargs = {**kwargs, "function_call": function_call} - return super().bind(functions=formatted_functions, **kwargs) - def bind_tools( self, tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]], @@ -1622,7 +1420,7 @@ class BaseChatOpenAI(BaseChatModel): strict: Optional[bool] = None, parallel_tool_calls: Optional[bool] = None, **kwargs: Any, - ) -> Runnable[LanguageModelInput, BaseMessage]: + ) -> Runnable[LanguageModelInput, MessageV1]: """Bind tool-like objects to this chat model. Assumes model is compatible with OpenAI tool-calling API. @@ -1752,7 +1550,7 @@ class BaseChatOpenAI(BaseChatModel): include_raw: If False then only the parsed structured output is returned. If an error occurs during model output parsing it will be raised. If True - then both the raw model response (a BaseMessage) and the parsed model + then both the raw model response (an AIMessage) and the parsed model response will be returned. If an error occurs during output parsing it will be caught and returned as well. The final output is always a dict with keys "raw", "parsed", and "parsing_error". 
@@ -1822,7 +1620,7 @@ class BaseChatOpenAI(BaseChatModel): | If ``include_raw`` is True, then Runnable outputs a dict with keys: - - "raw": BaseMessage + - "raw": AIMessage - "parsed": None if there was a parsing error, otherwise the type depends on the ``schema`` as described above. - "parsing_error": Optional[BaseException] @@ -1986,30 +1784,20 @@ class BaseChatOpenAI(BaseChatModel): filtered[k] = v return filtered - def _get_generation_chunk_from_completion( + def _get_message_chunk_from_completion( self, completion: openai.BaseModel - ) -> ChatGenerationChunk: + ) -> AIMessageChunkV1: """Get chunk from completion (e.g., from final completion of a stream).""" - chat_result = self._create_chat_result(completion) - chat_message = chat_result.generations[0].message - if isinstance(chat_message, AIMessage): - usage_metadata = chat_message.usage_metadata - # Skip tool_calls, already sent as chunks - if "tool_calls" in chat_message.additional_kwargs: - chat_message.additional_kwargs.pop("tool_calls") - else: - usage_metadata = None - message = AIMessageChunk( + ai_message = self._create_ai_message(completion) + return AIMessageChunkV1( content="", - additional_kwargs=chat_message.additional_kwargs, - usage_metadata=usage_metadata, - ) - return ChatGenerationChunk( - message=message, generation_info=chat_result.llm_output + usage_metadata=ai_message.usage_metadata, + response_metadata=ai_message.response_metadata, + parsed=ai_message.parsed, ) -class ChatOpenAI(BaseChatOpenAI): # type: ignore[override] +class ChatOpenAI(BaseChatOpenAIV1): # type: ignore[override] """OpenAI chat model integration. .. dropdown:: Setup @@ -2702,7 +2490,7 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override] message["role"] = "developer" return payload - def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGenerationChunk]: + def _stream(self, *args: Any, **kwargs: Any) -> Iterator[AIMessageChunkV1]: """Route to Chat Completions or Responses API.""" if self._use_responses_api({**kwargs, **self.model_kwargs}): return super()._stream_responses(*args, **kwargs) @@ -2711,7 +2499,7 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override] async def _astream( self, *args: Any, **kwargs: Any - ) -> AsyncIterator[ChatGenerationChunk]: + ) -> AsyncIterator[AIMessageChunkV1]: """Route to Chat Completions or Responses API.""" if self._use_responses_api({**kwargs, **self.model_kwargs}): async for chunk in super()._astream_responses(*args, **kwargs): @@ -2772,7 +2560,7 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override] include_raw: If False then only the parsed structured output is returned. If an error occurs during model output parsing it will be raised. If True - then both the raw model response (a BaseMessage) and the parsed model + then both the raw model response (an AIMessage) and the parsed model response will be returned. If an error occurs during output parsing it will be caught and returned as well. The final output is always a dict with keys "raw", "parsed", and "parsing_error". @@ -2847,7 +2635,7 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override] | If ``include_raw`` is True, then Runnable outputs a dict with keys: - - "raw": BaseMessage + - "raw": AIMessage - "parsed": None if there was a parsing error, otherwise the type depends on the ``schema`` as described above. 
- "parsing_error": Optional[BaseException] @@ -3225,15 +3013,24 @@ def _convert_to_openai_response_format( def _oai_structured_outputs_parser( - ai_msg: AIMessage, schema: type[_BM] + ai_msg: AIMessageV1, schema: type[_BM] ) -> Optional[PydanticBaseModel]: - if parsed := ai_msg.additional_kwargs.get("parsed"): + if parsed := ai_msg.parsed: if isinstance(parsed, dict): return schema(**parsed) else: return parsed - elif ai_msg.additional_kwargs.get("refusal"): - raise OpenAIRefusalError(ai_msg.additional_kwargs["refusal"]) + elif any( + block["type"] == "non_standard" and block["value"].get("type") == "refusal" + for block in ai_msg.content + ): + refusal = next( + block["value"]["text"] + for block in ai_msg.content + if block["type"] == "non_standard" + and block["value"].get("type") == "refusal" + ) + raise OpenAIRefusalError(refusal) elif ai_msg.tool_calls: return None else: @@ -3335,12 +3132,12 @@ def _use_responses_api(payload: dict) -> bool: def _get_last_messages( - messages: Sequence[BaseMessage], -) -> tuple[Sequence[BaseMessage], Optional[str]]: + messages: Sequence[MessageV1], +) -> tuple[Sequence[MessageV1], Optional[str]]: """ Return 1. Every message after the most-recent AIMessage that has a non-empty - ``response_metadata["id"]`` (may be an empty list), + ``id`` (may be an empty list), 2. That id. If the most-recent AIMessage does not have an id (or there is no @@ -3348,9 +3145,9 @@ def _get_last_messages( """ for i in range(len(messages) - 1, -1, -1): msg = messages[i] - if isinstance(msg, AIMessage): - response_id = msg.response_metadata.get("id") - if response_id: + if isinstance(msg, AIMessageV1): + response_id = msg.id + if response_id and response_id.startswith("resp_"): return messages[i + 1 :], response_id else: return messages, None @@ -3359,7 +3156,7 @@ def _get_last_messages( def _construct_responses_api_payload( - messages: Sequence[BaseMessage], payload: dict + messages: Sequence[MessageV1], payload: dict ) -> dict: # Rename legacy parameters for legacy_token_param in ["max_tokens", "max_completion_tokens"]: @@ -3447,26 +3244,16 @@ def _construct_responses_api_payload( return payload -def _make_computer_call_output_from_message(message: ToolMessage) -> dict: - computer_call_output: dict = { - "call_id": message.tool_call_id, - "type": "computer_call_output", - } - if isinstance(message.content, list): - # Use first input_image block - output = next( - block - for block in message.content - if cast(dict, block)["type"] == "input_image" - ) - else: - # string, assume image_url - output = {"type": "input_image", "image_url": message.content} - computer_call_output["output"] = output - if "acknowledged_safety_checks" in message.additional_kwargs: - computer_call_output["acknowledged_safety_checks"] = message.additional_kwargs[ - "acknowledged_safety_checks" - ] +def _make_computer_call_output_from_message(message: ToolMessageV1) -> Optional[dict]: + computer_call_output = None + for block in message.content: + if ( + block["type"] == "non_standard" + and block["value"].get("type") == "computer_call_output" + ): + computer_call_output = block["value"] + break + return computer_call_output @@ -3484,23 +3271,24 @@ def _pop_index_and_sub_index(block: dict) -> dict: return new_block -def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list: +def _construct_responses_api_input(messages: Sequence[MessageV1]) -> list: """Construct the input for the OpenAI Responses API.""" input_ = [] for lc_msg in messages: - if isinstance(lc_msg, AIMessage): - lc_msg = 
@@ -3484,23 +3271,24 @@ def _pop_index_and_sub_index(block: dict) -> dict:
     return new_block
 
 
-def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
+def _construct_responses_api_input(messages: Sequence[MessageV1]) -> list:
     """Construct the input for the OpenAI Responses API."""
     input_ = []
     for lc_msg in messages:
-        if isinstance(lc_msg, AIMessage):
-            lc_msg = _convert_from_v03_ai_message(lc_msg)
-            lc_msg = _convert_from_v1_to_responses(lc_msg)
         msg = _convert_message_to_dict(lc_msg)
+        if isinstance(lc_msg, AIMessageV1):
+            msg["content"] = _convert_from_v1_to_responses(
+                msg["content"], lc_msg.tool_calls
+            )
         # "name" parameter unsupported
         if "name" in msg:
             msg.pop("name")
         if msg["role"] == "tool":
             tool_output = msg["content"]
-            if lc_msg.additional_kwargs.get("type") == "computer_call_output":
-                computer_call_output = _make_computer_call_output_from_message(
-                    cast(ToolMessage, lc_msg)
-                )
+            computer_call_output = _make_computer_call_output_from_message(
+                cast(ToolMessageV1, lc_msg)
+            )
+            if computer_call_output:
                 input_.append(computer_call_output)
             else:
                 if not isinstance(tool_output, str):
@@ -3634,8 +3422,7 @@ def _construct_lc_result_from_responses_api(
     response: Response,
     schema: Optional[type[_BM]] = None,
     metadata: Optional[dict] = None,
-    output_version: str = "v0",
-) -> ChatResult:
-    """Construct ChatResponse from OpenAI Response API response."""
+) -> AIMessageV1:
+    """Construct an AIMessage (v1) from an OpenAI Responses API response."""
     if response.error:
         raise ValueError(response.error)
@@ -3661,6 +3448,7 @@
     if metadata:
         response_metadata.update(metadata)
     # for compatibility with chat completion calls.
+    response_metadata["model_provider"] = "openai"
     response_metadata["model_name"] = response_metadata.get("model")
     if response.usage:
         usage_metadata = _create_usage_metadata_responses(response.usage.model_dump())
@@ -3668,9 +3456,9 @@
         usage_metadata = None
 
     content_blocks: list = []
-    tool_calls = []
-    invalid_tool_calls = []
-    additional_kwargs: dict = {}
+    tool_calls: list[ToolCall] = []
+    invalid_tool_calls: list[InvalidToolCall] = []
+    parsed = None
     for output in response.output:
         if output.type == "message":
             for content in output.content:
@@ -3686,7 +3474,7 @@
                     }
                     content_blocks.append(block)
                     if hasattr(content, "parsed"):
-                        additional_kwargs["parsed"] = content.parsed
+                        parsed = content.parsed
                 if content.type == "refusal":
                     content_blocks.append(
                         {"type": "refusal", "refusal": content.refusal, "id": output.id}
@@ -3706,7 +3494,7 @@
                     "args": args,
                     "id": output.call_id,
                 }
-                tool_calls.append(tool_call)
+                tool_calls.append(cast(ToolCall, tool_call))
             else:
                 tool_call = {
                     "type": "invalid_tool_call",
@@ -3715,7 +3503,7 @@
                     "id": output.call_id,
                     "error": error,
                 }
-                invalid_tool_calls.append(tool_call)
+                invalid_tool_calls.append(cast(InvalidToolCall, tool_call))
         elif output.type in (
             "reasoning",
             "web_search_call",
@@ -3746,7 +3534,7 @@
     # )
     if (
         schema is not None
-        and "parsed" not in additional_kwargs
+        and not parsed
         and response.output_text  # tool calls can generate empty output text
         and response.text
         and (text_config := response.text.model_dump())
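The two ``cast`` calls a few hunks up exist because the dicts are assembled field by field before being appended; the target shapes are the ``ToolCall`` and ``InvalidToolCall`` TypedDicts. A small sketch of the distinction, mirroring the parsing logic here (tool name, id, and error text are illustrative):

```python
import json

raw_args = '{"city": "Paris"'  # truncated JSON, as a provider might return it

try:
    call = {
        "type": "tool_call",
        "name": "get_weather",  # illustrative tool name
        "args": json.loads(raw_args),  # parsed arguments dict
        "id": "call_abc123",  # illustrative provider call id
    }
except json.JSONDecodeError as e:
    # Unparseable arguments are preserved as an invalid_tool_call instead,
    # with the raw string and the parse error kept for debugging.
    call = {
        "type": "invalid_tool_call",
        "name": "get_weather",
        "args": raw_args,
        "id": "call_abc123",
        "error": str(e),
    }
```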
@@ -3759,44 +3547,40 @@ def _construct_lc_result_from_responses_api(
                 parsed = schema(**parsed_dict)
             else:
                 parsed = parsed_dict
-            additional_kwargs["parsed"] = parsed
         except json.JSONDecodeError:
             pass
-    message = AIMessage(
-        content=content_blocks,
+
+    content_v1 = _convert_to_v1_from_responses(content_blocks)
+    message = AIMessageV1(
+        content=content_v1,
         id=response.id,
         usage_metadata=usage_metadata,
-        response_metadata=response_metadata,
-        additional_kwargs=additional_kwargs,
+        response_metadata=cast(ResponseMetadata, response_metadata),
         tool_calls=tool_calls,
         invalid_tool_calls=invalid_tool_calls,
+        parsed=parsed,
     )
-    if output_version == "v0":
-        message = _convert_to_v03_ai_message(message)
-    elif output_version == "v1":
-        message = _convert_to_v1_from_responses(message)
-        if response.tools and any(
-            tool.type == "image_generation" for tool in response.tools
-        ):
-            # Get mime_time from tool definition and add to image generations
-            # if missing (primarily for tracing purposes).
-            image_generation_call = next(
-                tool for tool in response.tools if tool.type == "image_generation"
-            )
-            if image_generation_call.output_format:
-                mime_type = f"image/{image_generation_call.output_format}"
-                for content_block in message.content:
-                    # OK to mutate output message
-                    if (
-                        isinstance(content_block, dict)
-                        and content_block.get("type") == "image"
-                        and "base64" in content_block
-                        and "mime_type" not in block
-                    ):
-                        block["mime_type"] = mime_type
-    else:
-        pass
-    return ChatResult(generations=[ChatGeneration(message=message)])
+    if response.tools and any(
+        tool.type == "image_generation" for tool in response.tools
+    ):
+        # Get mime_type from tool definition and add to image generations
+        # if missing (primarily for tracing purposes).
+        image_generation_call = next(
+            tool for tool in response.tools if tool.type == "image_generation"
+        )
+        if image_generation_call.output_format:
+            mime_type = f"image/{image_generation_call.output_format}"
+            for content_block in message.content:
+                # OK to mutate output message
+                if (
+                    isinstance(content_block, dict)
+                    and content_block.get("type") == "image"
+                    and "base64" in content_block
+                    and "mime_type" not in content_block
+                ):
+                    content_block["mime_type"] = mime_type
+
+    return message
 
 
 def _convert_responses_chunk_to_generation_chunk(
@@ -3806,9 +3590,8 @@
     current_sub_index: int,  # index of content block in output item
     schema: Optional[type[_BM]] = None,
     metadata: Optional[dict] = None,
-    has_reasoning: bool = False,
     output_version: str = "v0",
-) -> tuple[int, int, int, Optional[ChatGenerationChunk]]:
+) -> tuple[int, int, int, Optional[AIMessageChunkV1]]:
     def _advance(output_idx: int, sub_idx: Optional[int] = None) -> None:
         """Advance indexes tracked during streaming.
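With ``_stream``/``_astream`` now yielding message chunks directly (see the routing changes earlier in this file), a consumer no longer unwraps ``chunk.message``. A minimal sketch, assuming the v1 chunk semantics introduced here (model name and prompt are illustrative):

```python
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini")  # illustrative model name

# Under this patch, stream() yields (v1) AIMessageChunk values, so content
# blocks and usage metadata hang off the chunk itself.
for chunk in llm.stream("Write a haiku about unified diffs"):
    for block in chunk.content:
        if isinstance(block, dict) and block.get("type") == "text":
            print(block.get("text", ""), end="", flush=True)
```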
@@ -3855,9 +3638,9 @@ def _convert_responses_chunk_to_generation_chunk(
 
     content = []
     tool_call_chunks: list = []
-    additional_kwargs: dict = {}
+    parsed = None
     if metadata:
-        response_metadata = metadata
+        response_metadata = cast(ResponseMetadata, metadata)
     else:
         response_metadata = {}
     usage_metadata = None
@@ -3899,21 +3682,13 @@
         id = chunk.response.id
         response_metadata["id"] = chunk.response.id  # Backwards compatibility
     elif chunk.type == "response.completed":
-        msg = cast(
-            AIMessage,
-            (
-                _construct_lc_result_from_responses_api(
-                    chunk.response, schema=schema, output_version=output_version
-                )
-                .generations[0]
-                .message
-            ),
-        )
-        if parsed := msg.additional_kwargs.get("parsed"):
-            additional_kwargs["parsed"] = parsed
+        msg = _construct_lc_result_from_responses_api(chunk.response, schema=schema)
+        if msg.parsed:
+            parsed = msg.parsed
         usage_metadata = msg.usage_metadata
         response_metadata = {
-            k: v for k, v in msg.response_metadata.items() if k != "id"
+            **response_metadata,
+            **{k: v for k, v in msg.response_metadata.items() if k != "id"},  # type: ignore[typeddict-item]
         }
     elif chunk.type == "response.output_item.added" and chunk.item.type == "message":
         if output_version == "v0":
@@ -4019,33 +3794,22 @@
     else:
         return current_index, current_output_index, current_sub_index, None
 
-    message = AIMessageChunk(
-        content=content,  # type: ignore[arg-type]
+    content_v1 = _convert_to_v1_from_responses(content)
+    for content_block in content_v1:
+        if (
+            isinstance(content_block, dict)
+            and content_block.get("index", -1) > current_index
+        ):
+            # blocks were added for v1
+            current_index = content_block["index"]
+
+    message = AIMessageChunkV1(
+        content=content_v1,
         tool_call_chunks=tool_call_chunks,
         usage_metadata=usage_metadata,
         response_metadata=response_metadata,
-        additional_kwargs=additional_kwargs,
+        parsed=parsed,
         id=id,
     )
-    if output_version == "v0":
-        message = cast(
-            AIMessageChunk,
-            _convert_to_v03_ai_message(message, has_reasoning=has_reasoning),
-        )
-    elif output_version == "v1":
-        message = cast(AIMessageChunk, _convert_to_v1_from_responses(message))
-        for content_block in message.content:
-            if (
-                isinstance(content_block, dict)
-                and content_block.get("index", -1) > current_index
-            ):
-                # blocks were added for v1
-                current_index = content_block["index"]
-    else:
-        pass
-    return (
-        current_index,
-        current_output_index,
-        current_sub_index,
-        ChatGenerationChunk(message=message),
-    )
+
+    return (current_index, current_output_index, current_sub_index, message)
diff --git a/libs/partners/openai/pyproject.toml b/libs/partners/openai/pyproject.toml
index 5bffdabcf44..a54595796de 100644
--- a/libs/partners/openai/pyproject.toml
+++ b/libs/partners/openai/pyproject.toml
@@ -56,6 +56,8 @@ langchain-tests = { path = "../../standard-tests", editable = true }
 
 [tool.mypy]
 disallow_untyped_defs = "True"
+disable_error_code = ["typeddict-unknown-key"]
+
 [[tool.mypy.overrides]]
 module = "transformers"
 ignore_missing_imports = true
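A note on the mypy change: ``ResponseMetadata`` is a ``total=False`` TypedDict whose docstring explicitly permits extra keys, so provider-specific writes would otherwise trip mypy's ``typeddict-unknown-key`` check throughout this package. A trimmed sketch of the pattern the suppression allows (the class below is a stand-in, and the extra key is illustrative):

```python
from typing import TypedDict


class ResponseMetadata(TypedDict, total=False):
    """Stand-in for the real class in langchain_core.messages.v1."""

    model_provider: str
    model_name: str


# Extra keys are intended to be allowed at runtime, but without
# disable_error_code = ["typeddict-unknown-key"] mypy rejects this literal.
metadata: ResponseMetadata = {
    "model_provider": "openai",
    "model_name": "gpt-4o-mini",  # illustrative
    "service_tier": "default",  # illustrative non-standard provider key
}
```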