fix(openai): fix tracing and typing on standard outputs branch (#32326)

Author: ccurme, 2025-07-30 14:02:15 -03:00 (committed by GitHub)
Parent: 8cf97e838c
Commit: 309d1a232a
Signature: no known key found for this signature in database (GPG Key ID: B5690EEEBB952194)

18 changed files with 335 additions and 535 deletions

View File

@ -253,19 +253,20 @@ def shielded(func: Func) -> Func:
def _convert_llm_events(
event_name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> None:
) -> tuple[tuple[Any, ...], dict[str, Any]]:
args_list = list(args)
if (
event_name == "on_chat_model_start"
and isinstance(args[1], list)
and args[1]
and isinstance(args[1][0], MessageV1Types)
and isinstance(args_list[1], list)
and args_list[1]
and isinstance(args_list[1][0], MessageV1Types)
):
batch = [
convert_from_v1_message(item)
for item in args[1]
for item in args_list[1]
if isinstance(item, MessageV1Types)
]
args[1] = [batch] # type: ignore[index]
args_list[1] = [batch]
elif (
event_name == "on_llm_new_token"
and "chunk" in kwargs
@ -273,12 +274,21 @@ def _convert_llm_events(
):
chunk = kwargs["chunk"]
kwargs["chunk"] = ChatGenerationChunk(text=chunk.text, message=chunk)
elif event_name == "on_llm_end" and isinstance(args[0], MessageV1Types):
args[0] = LLMResult( # type: ignore[index]
generations=[[ChatGeneration(text=args[0].text, message=args[0])]]
elif event_name == "on_llm_end" and isinstance(args_list[0], MessageV1Types):
args_list[0] = LLMResult(
generations=[
[
ChatGeneration(
text=args_list[0].text,
message=convert_from_v1_message(args_list[0]),
)
]
]
)
else:
return
pass
return tuple(args_list), kwargs
def handle_event(
@ -310,7 +320,7 @@ def handle_event(
handler, ignore_condition_name
):
if not handler.accepts_new_messages:
_convert_llm_events(event_name, args, kwargs)
args, kwargs = _convert_llm_events(event_name, args, kwargs)
event = getattr(handler, event_name)(*args, **kwargs)
if asyncio.iscoroutine(event):
coros.append(event)
@ -406,7 +416,7 @@ async def _ahandle_event_for_handler(
try:
if ignore_condition_name is None or not getattr(handler, ignore_condition_name):
if not handler.accepts_new_messages:
_convert_llm_events(event_name, args, kwargs)
args, kwargs = _convert_llm_events(event_name, args, kwargs)
event = getattr(handler, event_name)
if asyncio.iscoroutinefunction(event):
await event(*args, **kwargs)
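For context on this hunk: positional handler arguments arrive as a tuple, so the old in-place assignment (`args[1] = ...`) could not work at runtime; the rewrite copies the tuple into a list, rewrites the relevant position, and returns a fresh `(args, kwargs)` pair that the callers now unpack. A minimal sketch of that pattern, with hypothetical names and none of the library code:

def _demo_convert(args: tuple, kwargs: dict) -> tuple:
    # tuples are immutable, so copy to a list before rewriting a position
    args_list = list(args)
    args_list[0] = "converted"
    return tuple(args_list), kwargs

args, kwargs = _demo_convert(("original",), {})
assert args == ("converted",)  # caller re-binds to the converted arguments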

View File

@ -328,20 +328,6 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
does not properly support streaming.
"""
output_version: str = "v0"
"""Version of AIMessage output format to use.
This field is used to roll-out new output formats for chat model AIMessages
in a backwards-compatible way.
``'v1'`` standardizes output format using a list of typed ContentBlock dicts. We
recommend this for new applications.
All chat models currently support the default of ``"v0"``.
.. versionadded:: 0.4
"""
@model_validator(mode="before")
@classmethod
def raise_deprecation(cls, values: dict) -> Any:

View File

@ -7,6 +7,7 @@ import typing
import warnings
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
from functools import cached_property
from operator import itemgetter
from typing import (
TYPE_CHECKING,
@ -41,6 +42,7 @@ from langchain_core.language_models.base import (
_get_token_ids_default_method,
_get_verbosity,
)
from langchain_core.load import dumpd
from langchain_core.messages import (
convert_to_openai_image_block,
get_buffer_string,
@ -312,6 +314,10 @@ class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC
arbitrary_types_allowed=True,
)
@cached_property
def _serialized(self) -> dict[str, Any]:
return dumpd(self)
# --- Runnable methods ---
@field_validator("verbose", mode="before")
@ -434,7 +440,7 @@ class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC
self.metadata,
)
(run_manager,) = callback_manager.on_chat_model_start(
{},
self._serialized,
_format_for_tracing(messages),
invocation_params=params,
options=options,
@ -500,7 +506,7 @@ class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC
self.metadata,
)
(run_manager,) = await callback_manager.on_chat_model_start(
{},
self._serialized,
_format_for_tracing(messages),
invocation_params=params,
options=options,
@ -578,7 +584,7 @@ class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC
self.metadata,
)
(run_manager,) = callback_manager.on_chat_model_start(
{},
self._serialized,
_format_for_tracing(messages),
invocation_params=params,
options=options,
@ -647,7 +653,7 @@ class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC
self.metadata,
)
(run_manager,) = await callback_manager.on_chat_model_start(
{},
self._serialized,
_format_for_tracing(messages),
invocation_params=params,
options=options,
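The `_serialized` property added above replaces the empty `{}` previously passed to `on_chat_model_start`, so tracers receive the serialized model, and `cached_property` ensures `dumpd(self)` is computed only once per instance even though four call sites use it. A minimal sketch of the caching behavior, using a stand-in class rather than the library model:

from functools import cached_property

class _FakeModel:
    calls = 0

    @cached_property
    def _serialized(self) -> dict:
        # stand-in for dumpd(self); the body runs once, then the dict is cached
        type(self).calls += 1
        return {"name": type(self).__name__}

m = _FakeModel()
assert m._serialized is m._serialized  # second access returns the cached dict
assert _FakeModel.calls == 1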

View File

@ -233,6 +233,8 @@ class InvalidToolCall(TypedDict):
"""An identifier associated with the tool call."""
error: Optional[str]
"""An error message associated with the tool call."""
index: NotRequired[int]
"""Index of block in aggregate response. Used during streaming."""
type: Literal["invalid_tool_call"]
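The new `index` key is declared `NotRequired`, so existing `InvalidToolCall` dicts without an index remain valid; the field only appears while chunks are aggregated during streaming. A minimal sketch of the pattern, as a hypothetical standalone TypedDict rather than the library class:

from typing import Literal, Optional
from typing_extensions import NotRequired, TypedDict

class InvalidToolCallExample(TypedDict):
    error: Optional[str]
    index: NotRequired[int]             # present only during streaming
    type: Literal["invalid_tool_call"]

chunk: InvalidToolCallExample = {"error": None, "type": "invalid_tool_call"}
assert "index" not in chunk             # still valid without the optional key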

View File

@ -300,9 +300,8 @@ def test_llm_representation_for_serializable() -> None:
assert chat._get_llm_string() == (
'{"id": ["tests", "unit_tests", "language_models", "chat_models", '
'"test_cache", "CustomChat"], "kwargs": {"messages": {"id": '
'["builtins", "list_iterator"], "lc": 1, "type": "not_implemented"}, '
'"output_version": "v0"}, "lc": 1, "name": "CustomChat", "type": '
"\"constructor\"}---[('stop', None)]"
'["builtins", "list_iterator"], "lc": 1, "type": "not_implemented"}}, "lc": '
'1, "name": "CustomChat", "type": "constructor"}---[(\'stop\', None)]'
)

View File

@ -763,6 +763,10 @@
]),
'title': 'Id',
}),
'index': dict({
'title': 'Index',
'type': 'integer',
}),
'name': dict({
'anyOf': list([
dict({
@ -2201,6 +2205,10 @@
]),
'title': 'Id',
}),
'index': dict({
'title': 'Index',
'type': 'integer',
}),
'name': dict({
'anyOf': list([
dict({

View File

@ -1166,6 +1166,10 @@
]),
'title': 'Id',
}),
'index': dict({
'title': 'Index',
'type': 'integer',
}),
'name': dict({
'anyOf': list([
dict({

View File

@ -2711,6 +2711,10 @@
]),
'title': 'Id',
}),
'index': dict({
'title': 'Index',
'type': 'integer',
}),
'name': dict({
'anyOf': list([
dict({
@ -4193,6 +4197,10 @@
]),
'title': 'Id',
}),
'index': dict({
'title': 'Index',
'type': 'integer',
}),
'name': dict({
'anyOf': list([
dict({
@ -5706,6 +5714,10 @@
]),
'title': 'Id',
}),
'index': dict({
'title': 'Index',
'type': 'integer',
}),
'name': dict({
'anyOf': list([
dict({
@ -7094,6 +7106,10 @@
]),
'title': 'Id',
}),
'index': dict({
'title': 'Index',
'type': 'integer',
}),
'name': dict({
'anyOf': list([
dict({
@ -8618,6 +8634,10 @@
]),
'title': 'Id',
}),
'index': dict({
'title': 'Index',
'type': 'integer',
}),
'name': dict({
'anyOf': list([
dict({
@ -10051,6 +10071,10 @@
]),
'title': 'Id',
}),
'index': dict({
'title': 'Index',
'type': 'integer',
}),
'name': dict({
'anyOf': list([
dict({
@ -11483,6 +11507,10 @@
]),
'title': 'Id',
}),
'index': dict({
'title': 'Index',
'type': 'integer',
}),
'name': dict({
'anyOf': list([
dict({
@ -12957,6 +12985,10 @@
]),
'title': 'Id',
}),
'index': dict({
'title': 'Index',
'type': 'integer',
}),
'name': dict({
'anyOf': list([
dict({

View File

@ -205,7 +205,6 @@ def test_configurable_with_default() -> None:
"name": None,
"bound": {
"name": None,
"output_version": "v0",
"disable_streaming": False,
"model": "claude-3-sonnet-20240229",
"mcp_servers": None,

View File

@ -71,7 +71,7 @@ import json
from collections.abc import Iterable, Iterator
from typing import Any, Literal, Optional, Union, cast
from langchain_core.messages import AIMessage, AIMessageChunk, is_data_content_block
from langchain_core.messages import AIMessage, is_data_content_block
from langchain_core.messages import content_blocks as types
from langchain_core.messages.v1 import AIMessage as AIMessageV1
@ -266,32 +266,6 @@ def _convert_from_v03_ai_message(message: AIMessage) -> AIMessage:
# v1 / Chat Completions
def _convert_to_v1_from_chat_completions(message: AIMessage) -> AIMessage:
"""Mutate a Chat Completions message to v1 format."""
if isinstance(message.content, str):
if message.content:
message.content = [{"type": "text", "text": message.content}]
else:
message.content = []
for tool_call in message.tool_calls:
if id_ := tool_call.get("id"):
message.content.append({"type": "tool_call", "id": id_})
if "tool_calls" in message.additional_kwargs:
_ = message.additional_kwargs.pop("tool_calls")
if "token_usage" in message.response_metadata:
_ = message.response_metadata.pop("token_usage")
return message
def _convert_to_v1_from_chat_completions_chunk(chunk: AIMessageChunk) -> AIMessageChunk:
result = _convert_to_v1_from_chat_completions(cast(AIMessage, chunk))
return cast(AIMessageChunk, result)
def _convert_from_v1_to_chat_completions(message: AIMessageV1) -> AIMessageV1:
"""Convert a v1 message to the Chat Completions format."""
new_content: list[types.ContentBlock] = []
@ -341,14 +315,14 @@ def _convert_annotation_to_v1(annotation: dict[str, Any]) -> dict[str, Any]:
return non_standard_annotation
def _explode_reasoning(block: dict[str, Any]) -> Iterable[dict[str, Any]]:
if block.get("type") != "reasoning" or "summary" not in block:
yield block
def _explode_reasoning(block: dict[str, Any]) -> Iterable[types.ReasoningContentBlock]:
if "summary" not in block:
yield cast(types.ReasoningContentBlock, block)
return
if not block["summary"]:
_ = block.pop("summary", None)
yield block
yield cast(types.ReasoningContentBlock, block)
return
# Common part for every exploded line, except 'summary'
@ -364,7 +338,7 @@ def _explode_reasoning(block: dict[str, Any]) -> Iterable[dict[str, Any]]:
new_block["reasoning"] = part.get("text", "")
if idx == 0:
new_block.update(first_only)
yield new_block
yield cast(types.ReasoningContentBlock, new_block)
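For context on the `_explode_reasoning` typing change above: the helper splits one Responses reasoning item that carries several summary parts into one v1 reasoning block per part. A rough sketch of that behavior with plain dicts (assumed shapes, not the library types):

from collections.abc import Iterable
from typing import Any

def explode_reasoning(block: dict[str, Any]) -> Iterable[dict[str, Any]]:
    # one block per summary part; non-summary keys are carried over
    # (the real helper keeps some keys only on the first exploded block)
    parts = block.get("summary") or []
    common = {k: v for k, v in block.items() if k != "summary"}
    if not parts:
        yield common
        return
    for part in parts:
        yield {**common, "reasoning": part.get("text", "")}

blocks = list(explode_reasoning(
    {"type": "reasoning", "id": "abc234",
     "summary": [{"type": "summary_text", "text": "foo "},
                 {"type": "summary_text", "text": "bar"}]}
))
assert [b["reasoning"] for b in blocks] == ["foo ", "bar"]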
def _convert_to_v1_from_responses(
@ -374,7 +348,7 @@ def _convert_to_v1_from_responses(
) -> list[types.ContentBlock]:
"""Mutate a Responses message to v1 format."""
def _iter_blocks() -> Iterable[dict[str, Any]]:
def _iter_blocks() -> Iterable[types.ContentBlock]:
for block in content:
if not isinstance(block, dict):
continue
@ -385,7 +359,7 @@ def _convert_to_v1_from_responses(
block["annotations"] = [
_convert_annotation_to_v1(a) for a in block["annotations"]
]
yield block
yield cast(types.TextContentBlock, block)
elif block_type == "reasoning":
yield from _explode_reasoning(block)
@ -408,27 +382,29 @@ def _convert_to_v1_from_responses(
):
if extra_key in block:
new_block[extra_key] = block[extra_key]
yield new_block
yield cast(types.ImageContentBlock, new_block)
elif block_type == "function_call":
new_block = None
tool_call_block: Optional[types.ContentBlock] = None
call_id = block.get("call_id", "")
if call_id:
for tool_call in tool_calls or []:
if tool_call.get("id") == call_id:
new_block = tool_call.copy()
tool_call_block = cast(types.ToolCall, tool_call.copy())
break
else:
for invalid_tool_call in invalid_tool_calls or []:
if invalid_tool_call.get("id") == call_id:
new_block = invalid_tool_call.copy()
tool_call_block = cast(
types.InvalidToolCall, invalid_tool_call.copy()
)
break
if new_block:
if tool_call_block:
if "id" in block:
new_block["item_id"] = block["id"]
tool_call_block["item_id"] = block["id"]
if "index" in block:
new_block["index"] = block["index"]
yield new_block
tool_call_block["index"] = block["index"]
yield tool_call_block
elif block_type == "web_search_call":
web_search_call = {"type": "web_search_call", "id": block["id"]}
@ -448,8 +424,8 @@ def _convert_to_v1_from_responses(
web_search_result = {"type": "web_search_result", "id": block["id"]}
if "index" in block:
web_search_result["index"] = block["index"] + 1
yield web_search_call
yield web_search_result
yield cast(types.WebSearchCall, web_search_call)
yield cast(types.WebSearchResult, web_search_result)
elif block_type == "code_interpreter_call":
code_interpreter_call = {
@ -489,14 +465,14 @@ def _convert_to_v1_from_responses(
if "index" in block:
code_interpreter_result["index"] = block["index"] + 1
yield code_interpreter_call
yield code_interpreter_result
yield cast(types.CodeInterpreterCall, code_interpreter_call)
yield cast(types.CodeInterpreterResult, code_interpreter_result)
else:
new_block = {"type": "non_standard", "value": block}
if "index" in new_block["value"]:
new_block["index"] = new_block["value"].pop("index")
yield new_block
yield cast(types.NonStandardContentBlock, new_block)
return list(_iter_blocks())
@ -511,9 +487,9 @@ def _convert_annotation_from_v1(annotation: types.Annotation) -> dict[str, Any]:
if "title" in annotation:
new_ann["filename"] = annotation["title"]
if "file_id" in annotation:
new_ann["file_id"] = annotation["file_id"]
new_ann["file_id"] = annotation["file_id"] # type: ignore[typeddict-item]
if "file_index" in annotation:
new_ann["index"] = annotation["file_index"]
new_ann["index"] = annotation["file_index"] # type: ignore[typeddict-item]
return new_ann
@ -649,10 +625,11 @@ def _convert_from_v1_to_responses(
new_block = {"type": "function_call", "call_id": block["id"]}
if "item_id" in block:
new_block["id"] = block["item_id"] # type: ignore[typeddict-item]
if "name" in block and "arguments" in block:
if "name" in block:
new_block["name"] = block["name"]
if "arguments" in block:
new_block["arguments"] = block["arguments"] # type: ignore[typeddict-item]
else:
if any(key not in block for key in ("name", "arguments")):
matching_tool_calls = [
call for call in tool_calls if call["id"] == block["id"]
]

View File

@ -108,12 +108,7 @@ from langchain_openai.chat_models._client_utils import (
)
from langchain_openai.chat_models._compat import (
_convert_from_v03_ai_message,
_convert_from_v1_to_chat_completions,
_convert_from_v1_to_responses,
_convert_to_v03_ai_message,
_convert_to_v1_from_chat_completions,
_convert_to_v1_from_chat_completions_chunk,
_convert_to_v1_from_responses,
)
if TYPE_CHECKING:
@ -674,7 +669,7 @@ class BaseChatOpenAI(BaseChatModel):
.. versionadded:: 0.3.9
"""
output_version: str = "v0"
output_version: Literal["v0", "responses/v1"] = "v0"
"""Version of AIMessage output format to use.
This field is used to roll-out new output formats for chat model AIMessages
@ -685,7 +680,6 @@ class BaseChatOpenAI(BaseChatModel):
- ``'v0'``: AIMessage format as of langchain-openai 0.3.x.
- ``'responses/v1'``: Formats Responses API output
items into AIMessage content blocks.
- ``"v1"``: v1 of LangChain cross-provider standard.
Currently only impacts the Responses API. ``output_version='responses/v1'`` is
recommended.
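Narrowing `output_version` to a `Literal` means the removed ``"v1"`` value is now rejected at validation time instead of being silently accepted as a plain string (v1 output is handled by the separate `ChatOpenAIV1` class added in this release). A minimal sketch of the effect, assuming a Pydantic v2 model like the real class:

from typing import Literal
from pydantic import BaseModel, ValidationError

class _Example(BaseModel):
    output_version: Literal["v0", "responses/v1"] = "v0"

assert _Example(output_version="responses/v1").output_version == "responses/v1"
try:
    _Example(output_version="v1")       # no longer a valid value on this class
except ValidationError:
    pass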
@ -875,10 +869,6 @@ class BaseChatOpenAI(BaseChatModel):
message=default_chunk_class(content="", usage_metadata=usage_metadata),
generation_info=base_generation_info,
)
if self.output_version == "v1":
generation_chunk.message = _convert_to_v1_from_chat_completions_chunk(
cast(AIMessageChunk, generation_chunk.message)
)
return generation_chunk
choice = choices[0]
@ -906,20 +896,6 @@ class BaseChatOpenAI(BaseChatModel):
if usage_metadata and isinstance(message_chunk, AIMessageChunk):
message_chunk.usage_metadata = usage_metadata
if self.output_version == "v1":
message_chunk = cast(AIMessageChunk, message_chunk)
# Convert to v1 format
if isinstance(message_chunk.content, str):
message_chunk = _convert_to_v1_from_chat_completions_chunk(
message_chunk
)
if message_chunk.content:
message_chunk.content[0]["index"] = 0 # type: ignore[index]
else:
message_chunk = _convert_to_v1_from_chat_completions_chunk(
message_chunk
)
generation_chunk = ChatGenerationChunk(
message=message_chunk, generation_info=generation_info or None
)
@ -1212,12 +1188,7 @@ class BaseChatOpenAI(BaseChatModel):
else:
payload = _construct_responses_api_payload(messages, payload)
else:
payload["messages"] = [
_convert_message_to_dict(_convert_from_v1_to_chat_completions(m))
if isinstance(m, AIMessage)
else _convert_message_to_dict(m)
for m in messages
]
payload["messages"] = [_convert_message_to_dict(m) for m in messages]
return payload
def _create_chat_result(
@ -1283,11 +1254,6 @@ class BaseChatOpenAI(BaseChatModel):
if hasattr(message, "refusal"):
generations[0].message.additional_kwargs["refusal"] = message.refusal
if self.output_version == "v1":
_ = llm_output.pop("token_usage", None)
generations[0].message = _convert_to_v1_from_chat_completions(
cast(AIMessage, generations[0].message)
)
return ChatResult(generations=generations, llm_output=llm_output)
async def _astream(
@ -3611,7 +3577,6 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
for lc_msg in messages:
if isinstance(lc_msg, AIMessage):
lc_msg = _convert_from_v03_ai_message(lc_msg)
lc_msg = _convert_from_v1_to_responses(lc_msg)
msg = _convert_message_to_dict(lc_msg)
# "name" parameter unsupported
if "name" in msg:
@ -3755,7 +3720,7 @@ def _construct_lc_result_from_responses_api(
response: Response,
schema: Optional[type[_BM]] = None,
metadata: Optional[dict] = None,
output_version: str = "v0",
output_version: Literal["v0", "responses/v1"] = "v0",
) -> ChatResult:
"""Construct ChatResponse from OpenAI Response API response."""
if response.error:
@ -3894,27 +3859,6 @@ def _construct_lc_result_from_responses_api(
)
if output_version == "v0":
message = _convert_to_v03_ai_message(message)
elif output_version == "v1":
message = _convert_to_v1_from_responses(message)
if response.tools and any(
tool.type == "image_generation" for tool in response.tools
):
# Get mime_time from tool definition and add to image generations
# if missing (primarily for tracing purposes).
image_generation_call = next(
tool for tool in response.tools if tool.type == "image_generation"
)
if image_generation_call.output_format:
mime_type = f"image/{image_generation_call.output_format}"
for content_block in message.content:
# OK to mutate output message
if (
isinstance(content_block, dict)
and content_block.get("type") == "image"
and "base64" in content_block
and "mime_type" not in block
):
block["mime_type"] = mime_type
else:
pass
return ChatResult(generations=[ChatGeneration(message=message)])
@ -3928,7 +3872,7 @@ def _convert_responses_chunk_to_generation_chunk(
schema: Optional[type[_BM]] = None,
metadata: Optional[dict] = None,
has_reasoning: bool = False,
output_version: str = "v0",
output_version: Literal["v0", "responses/v1"] = "v0",
) -> tuple[int, int, int, Optional[ChatGenerationChunk]]:
def _advance(output_idx: int, sub_idx: Optional[int] = None) -> None:
"""Advance indexes tracked during streaming.
@ -3994,29 +3938,9 @@ def _convert_responses_chunk_to_generation_chunk(
annotation = chunk.annotation
else:
annotation = chunk.annotation.model_dump(exclude_none=True, mode="json")
if output_version == "v1":
content.append(
{
"type": "text",
"text": "",
"annotations": [annotation],
"index": current_index,
}
)
else:
content.append({"annotations": [annotation], "index": current_index})
content.append({"annotations": [annotation], "index": current_index})
elif chunk.type == "response.output_text.done":
if output_version == "v1":
content.append(
{
"type": "text",
"text": "",
"id": chunk.item_id,
"index": current_index,
}
)
else:
content.append({"id": chunk.item_id, "index": current_index})
content.append({"id": chunk.item_id, "index": current_index})
elif chunk.type == "response.created":
id = chunk.response.id
response_metadata["id"] = chunk.response.id # Backwards compatibility
@ -4092,34 +4016,21 @@ def _convert_responses_chunk_to_generation_chunk(
content.append({"type": "refusal", "refusal": chunk.refusal})
elif chunk.type == "response.output_item.added" and chunk.item.type == "reasoning":
_advance(chunk.output_index)
current_sub_index = 0
reasoning = chunk.item.model_dump(exclude_none=True, mode="json")
reasoning["index"] = current_index
content.append(reasoning)
elif chunk.type == "response.reasoning_summary_part.added":
if output_version in ("v0", "responses/v1"):
_advance(chunk.output_index)
content.append(
{
# langchain-core uses the `index` key to aggregate text blocks.
"summary": [
{
"index": chunk.summary_index,
"type": "summary_text",
"text": "",
}
],
"index": current_index,
"type": "reasoning",
}
)
else:
block: dict = {"type": "reasoning", "reasoning": ""}
if chunk.summary_index > 0:
_advance(chunk.output_index, chunk.summary_index)
block["id"] = chunk.item_id
block["index"] = current_index
content.append(block)
_advance(chunk.output_index)
content.append(
{
# langchain-core uses the `index` key to aggregate text blocks.
"summary": [
{"index": chunk.summary_index, "type": "summary_text", "text": ""}
],
"index": current_index,
"type": "reasoning",
}
)
elif chunk.type == "response.image_generation_call.partial_image":
# Partial images are not supported yet.
pass
@ -4154,15 +4065,6 @@ def _convert_responses_chunk_to_generation_chunk(
AIMessageChunk,
_convert_to_v03_ai_message(message, has_reasoning=has_reasoning),
)
elif output_version == "v1":
message = cast(AIMessageChunk, _convert_to_v1_from_responses(message))
for content_block in message.content:
if (
isinstance(content_block, dict)
and content_block.get("index", -1) > current_index
):
# blocks were added for v1
current_index = content_block["index"]
else:
pass
return (

View File

@ -154,7 +154,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> MessageV1:
invalid_tool_calls.append(
make_invalid_tool_call(raw_tool_call, str(e))
)
content.extend(tool_calls)
content.extend(cast(list[ToolCall], tool_calls))
if audio := _dict.get("audio"):
# TODO: populate standard fields
content.append(
@ -3796,7 +3796,7 @@ def _convert_responses_chunk_to_generation_chunk(
for content_block in content_v1:
if (
isinstance(content_block, dict)
and content_block.get("index", -1) > current_index
and (content_block.get("index") or -1) > current_index # type: ignore[operator]
):
# blocks were added for v1
current_index = content_block["index"]

View File

@ -2,7 +2,7 @@
import json
import os
from typing import Annotated, Any, Literal, Optional, cast
from typing import Annotated, Any, Literal, Optional, Union, cast
import openai
import pytest
@ -25,7 +25,9 @@ from langchain_openai import ChatOpenAI, ChatOpenAIV1
MODEL_NAME = "gpt-4o-mini"
def _check_response(response: Optional[BaseMessage], output_version) -> None:
def _check_response(
response: Optional[Union[BaseMessage, AIMessageV1]], output_version: str
) -> None:
if output_version == "v1":
assert isinstance(response, AIMessageV1) or isinstance(
response, AIMessageChunkV1
@ -36,8 +38,8 @@ def _check_response(response: Optional[BaseMessage], output_version) -> None:
for block in response.content:
assert isinstance(block, dict)
if block["type"] == "text":
assert isinstance(block["text"], str)
for annotation in block["annotations"]:
assert isinstance(block["text"], str) # type: ignore[typeddict-item]
for annotation in block["annotations"]: # type: ignore[typeddict-item]
if annotation["type"] == "file_citation":
assert all(
key in annotation
@ -52,7 +54,7 @@ def _check_response(response: Optional[BaseMessage], output_version) -> None:
if output_version == "v1":
text_content = response.text
else:
text_content = response.text()
text_content = response.text() # type: ignore[operator,misc]
assert isinstance(text_content, str)
assert text_content
assert response.usage_metadata
@ -60,7 +62,7 @@ def _check_response(response: Optional[BaseMessage], output_version) -> None:
assert response.usage_metadata["output_tokens"] > 0
assert response.usage_metadata["total_tokens"] > 0
assert response.response_metadata["model_name"]
assert response.response_metadata["service_tier"]
assert response.response_metadata["service_tier"] # type: ignore[typeddict-item]
@pytest.mark.default_cassette("test_web_search.yaml.gz")
@ -70,7 +72,7 @@ def test_web_search(output_version: Literal["responses/v1", "v1"]) -> None:
if output_version == "v1":
llm = ChatOpenAIV1(model=MODEL_NAME)
else:
llm = ChatOpenAI(model=MODEL_NAME, output_version=output_version)
llm = ChatOpenAI(model=MODEL_NAME, output_version=output_version) # type: ignore[assignment]
first_response = llm.invoke(
"What was a positive news story from today?",
tools=[{"type": "web_search_preview"}],
@ -87,7 +89,7 @@ def test_web_search(output_version: Literal["responses/v1", "v1"]) -> None:
assert isinstance(chunk, AIMessageChunkV1)
full = chunk if full is None else full + chunk
else:
full: Optional[BaseMessageChunk] = None
full: Optional[BaseMessageChunk] = None # type: ignore[no-redef]
for chunk in llm.stream(
"What was a positive news story from today?",
tools=[{"type": "web_search_preview"}],
@ -100,7 +102,7 @@ def test_web_search(output_version: Literal["responses/v1", "v1"]) -> None:
response = llm.invoke(
"what about a negative one",
tools=[{"type": "web_search_preview"}],
previous_response_id=first_response.response_metadata["id"],
previous_response_id=first_response.response_metadata["id"], # type: ignore[typeddict-item]
)
_check_response(response, output_version)
@ -122,6 +124,7 @@ def test_web_search(output_version: Literal["responses/v1", "v1"]) -> None:
_check_response(response, output_version)
for msg in [first_response, full, response]:
assert msg is not None
block_types = [block["type"] for block in msg.content] # type: ignore[index]
if output_version == "responses/v1":
assert block_types == ["web_search_call", "text"]
@ -246,6 +249,7 @@ def test_parsed_pydantic_schema(output_version: Literal["v0", "responses/v1"]) -
def test_parsed_pydantic_schema_v1() -> None:
llm = ChatOpenAIV1(model=MODEL_NAME, use_responses_api=True)
response = llm.invoke("how are ya", response_format=Foo)
assert response.text
parsed = Foo(**json.loads(response.text))
assert parsed == response.parsed
assert parsed.response
@ -258,6 +262,7 @@ def test_parsed_pydantic_schema_v1() -> None:
full = chunk if full is None else full + chunk
chunks.append(chunk)
assert isinstance(full, AIMessageChunkV1)
assert full.text
parsed = Foo(**json.loads(full.text))
assert parsed == full.parsed
assert parsed.response
@ -649,7 +654,7 @@ def test_code_interpreter_v1() -> None:
# Test streaming
# Use same container
container_id = tool_outputs[0]["container_id"]
container_id = tool_outputs[0]["container_id"] # type: ignore[typeddict-item]
llm_with_tools = llm.bind_tools(
[{"type": "code_interpreter", "container": container_id}]
)
@ -815,7 +820,9 @@ def test_mcp_builtin_zdr_v1() -> None:
@pytest.mark.default_cassette("test_image_generation_streaming.yaml.gz")
@pytest.mark.vcr
@pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
def test_image_generation_streaming(output_version: str) -> None:
def test_image_generation_streaming(
output_version: Literal["v0", "responses/v1"],
) -> None:
"""Test image generation streaming."""
llm = ChatOpenAI(
model="gpt-4.1", use_responses_api=True, output_version=output_version
@ -934,7 +941,9 @@ def test_image_generation_streaming_v1() -> None:
@pytest.mark.default_cassette("test_image_generation_multi_turn.yaml.gz")
@pytest.mark.vcr
@pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
def test_image_generation_multi_turn(output_version: str) -> None:
def test_image_generation_multi_turn(
output_version: Literal["v0", "responses/v1"],
) -> None:
"""Test multi-turn editing of image generation by passing in history."""
# Test multi-turn
llm = ChatOpenAI(

View File

@ -20,7 +20,9 @@ from langchain_core.messages import (
ToolCall,
ToolMessage,
)
from langchain_core.messages import content_blocks as types
from langchain_core.messages.ai import UsageMetadata
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.runnables import RunnableLambda
from langchain_core.tracers.base import BaseTracer
@ -54,7 +56,6 @@ from langchain_openai.chat_models._compat import (
_convert_from_v1_to_chat_completions,
_convert_from_v1_to_responses,
_convert_to_v03_ai_message,
_convert_to_v1_from_chat_completions,
_convert_to_v1_from_responses,
)
from langchain_openai.chat_models.base import (
@ -2301,7 +2302,7 @@ def test_mcp_tracing() -> None:
assert payload["tools"][0]["headers"]["Authorization"] == "Bearer PLACEHOLDER"
def test_compat_responses_v1() -> None:
def test_compat_responses_v03() -> None:
# Check compatibility with v0.3 message format
message_v03 = AIMessage(
content=[
@ -2366,265 +2367,147 @@ def test_compat_responses_v1() -> None:
"message_v1, expected",
[
(
AIMessage(
AIMessageV1(
[
{"type": "reasoning", "reasoning": "Reasoning text"},
{"type": "tool_call", "id": "call_123"},
{
"type": "tool_call",
"id": "call_123",
"name": "get_weather",
"args": {"location": "San Francisco"},
},
{
"type": "text",
"text": "Hello, world!",
"annotations": [
{"type": "url_citation", "url": "https://example.com"}
{"type": "citation", "url": "https://example.com"}
],
},
],
tool_calls=[
{
"type": "tool_call",
"id": "call_123",
"name": "get_weather",
"args": {"location": "San Francisco"},
}
],
id="chatcmpl-123",
response_metadata={"foo": "bar"},
response_metadata={"model_provider": "openai", "model_name": "gpt-4.1"},
),
AIMessage(
AIMessageV1(
[{"type": "text", "text": "Hello, world!"}],
tool_calls=[
{
"type": "tool_call",
"id": "call_123",
"name": "get_weather",
"args": {"location": "San Francisco"},
}
],
id="chatcmpl-123",
response_metadata={"foo": "bar"},
response_metadata={"model_provider": "openai", "model_name": "gpt-4.1"},
),
)
],
)
def test_convert_from_v1_to_chat_completions(
message_v1: AIMessage, expected: AIMessage
message_v1: AIMessageV1, expected: AIMessageV1
) -> None:
result = _convert_from_v1_to_chat_completions(message_v1)
assert result == expected
assert result.tool_calls == message_v1.tool_calls # tool calls remain cached
# Check no mutation
assert message_v1 != result
@pytest.mark.parametrize(
"message_chat_completions, expected",
[
(
AIMessage(
"Hello, world!", id="chatcmpl-123", response_metadata={"foo": "bar"}
),
AIMessage(
[{"type": "text", "text": "Hello, world!"}],
id="chatcmpl-123",
response_metadata={"foo": "bar"},
),
),
(
AIMessage(
[{"type": "text", "text": "Hello, world!"}],
tool_calls=[
{
"type": "tool_call",
"id": "call_123",
"name": "get_weather",
"args": {"location": "San Francisco"},
}
],
id="chatcmpl-123",
response_metadata={"foo": "bar"},
),
AIMessage(
[
{"type": "text", "text": "Hello, world!"},
{"type": "tool_call", "id": "call_123"},
],
tool_calls=[
{
"type": "tool_call",
"id": "call_123",
"name": "get_weather",
"args": {"location": "San Francisco"},
}
],
id="chatcmpl-123",
response_metadata={"foo": "bar"},
),
),
(
AIMessage(
"",
tool_calls=[
{
"type": "tool_call",
"id": "call_123",
"name": "get_weather",
"args": {"location": "San Francisco"},
}
],
id="chatcmpl-123",
response_metadata={"foo": "bar"},
additional_kwargs={"tool_calls": [{"foo": "bar"}]},
),
AIMessage(
[{"type": "tool_call", "id": "call_123"}],
tool_calls=[
{
"type": "tool_call",
"id": "call_123",
"name": "get_weather",
"args": {"location": "San Francisco"},
}
],
id="chatcmpl-123",
response_metadata={"foo": "bar"},
),
),
],
)
def test_convert_to_v1_from_chat_completions(
message_chat_completions: AIMessage, expected: AIMessage
) -> None:
result = _convert_to_v1_from_chat_completions(message_chat_completions)
assert result == expected
@pytest.mark.parametrize(
"message_v1, expected",
[
(
AIMessage(
AIMessageV1(
[
{"type": "reasoning", "id": "abc123"},
{"type": "reasoning", "id": "abc234", "reasoning": "foo "},
{"type": "reasoning", "id": "abc234", "reasoning": "bar"},
{"type": "tool_call", "id": "call_123"},
{
"type": "tool_call",
"id": "call_234",
"name": "get_weather_2",
"arguments": '{"location": "New York"}',
"item_id": "fc_123",
"id": "call_123",
"name": "get_weather",
"args": {"location": "San Francisco"},
},
cast(
ToolCall,
{
"type": "tool_call",
"id": "call_234",
"name": "get_weather_2",
"args": {"location": "New York"},
"item_id": "fc_123",
},
),
{"type": "text", "text": "Hello "},
{
"type": "text",
"text": "world",
"annotations": [
{"type": "url_citation", "url": "https://example.com"},
{
"type": "document_citation",
"title": "my doc",
"index": 1,
"file_id": "file_123",
},
{"type": "citation", "url": "https://example.com"},
cast(
types.Citation,
{
"type": "citation",
"title": "my doc",
"file_index": 1,
"file_id": "file_123",
},
),
{
"type": "non_standard_annotation",
"value": {"bar": "baz"},
},
],
},
{"type": "image", "base64": "...", "id": "img_123"},
{"type": "image", "base64": "...", "id": "ig_123"},
{
"type": "non_standard",
"value": {"type": "something_else", "foo": "bar"},
},
],
tool_calls=[
{
"type": "tool_call",
"id": "call_123",
"name": "get_weather",
"args": {"location": "San Francisco"},
},
{
# Make values different to check we pull from content when
# available
"type": "tool_call",
"id": "call_234",
"name": "get_weather_3",
"args": {"location": "Boston"},
},
],
id="resp123",
response_metadata={"foo": "bar"},
),
AIMessage(
[
{"type": "reasoning", "id": "abc123", "summary": []},
{
"type": "reasoning",
"id": "abc234",
"summary": [
{"type": "summary_text", "text": "foo "},
{"type": "summary_text", "text": "bar"},
],
},
{
"type": "function_call",
"call_id": "call_123",
"name": "get_weather",
"arguments": '{"location": "San Francisco"}',
},
{
"type": "function_call",
"call_id": "call_234",
"name": "get_weather_2",
"arguments": '{"location": "New York"}',
"id": "fc_123",
},
{"type": "text", "text": "Hello "},
{
"type": "text",
"text": "world",
"annotations": [
{"type": "url_citation", "url": "https://example.com"},
{
"type": "file_citation",
"filename": "my doc",
"index": 1,
"file_id": "file_123",
},
{"bar": "baz"},
],
},
{"type": "image_generation_call", "id": "img_123", "result": "..."},
{"type": "something_else", "foo": "bar"},
],
tool_calls=[
{
"type": "tool_call",
"id": "call_123",
"name": "get_weather",
"args": {"location": "San Francisco"},
},
{
# Make values different to check we pull from content when
# available
"type": "tool_call",
"id": "call_234",
"name": "get_weather_3",
"args": {"location": "Boston"},
},
],
id="resp123",
response_metadata={"foo": "bar"},
),
[
{"type": "reasoning", "id": "abc123", "summary": []},
{
"type": "reasoning",
"id": "abc234",
"summary": [
{"type": "summary_text", "text": "foo "},
{"type": "summary_text", "text": "bar"},
],
},
{
"type": "function_call",
"call_id": "call_123",
"name": "get_weather",
"arguments": '{"location": "San Francisco"}',
},
{
"type": "function_call",
"call_id": "call_234",
"name": "get_weather_2",
"arguments": '{"location": "New York"}',
"id": "fc_123",
},
{"type": "text", "text": "Hello "},
{
"type": "text",
"text": "world",
"annotations": [
{"type": "url_citation", "url": "https://example.com"},
{
"type": "file_citation",
"filename": "my doc",
"index": 1,
"file_id": "file_123",
},
{"bar": "baz"},
],
},
{"type": "image_generation_call", "id": "ig_123", "result": "..."},
{"type": "something_else", "foo": "bar"},
],
)
],
)
def test_convert_from_v1_to_responses(
message_v1: AIMessage, expected: AIMessage
message_v1: AIMessageV1, expected: AIMessageV1
) -> None:
result = _convert_from_v1_to_responses(message_v1)
result = _convert_from_v1_to_responses(message_v1.content, message_v1.tool_calls)
assert result == expected
# Check no mutation
@ -2632,139 +2515,118 @@ def test_convert_from_v1_to_responses(
@pytest.mark.parametrize(
"message_responses, expected",
"responses_content, tool_calls, expected_content",
[
(
AIMessage(
[
{"type": "reasoning", "id": "abc123"},
{
"type": "reasoning",
"id": "abc234",
"summary": [
{"type": "summary_text", "text": "foo "},
{"type": "summary_text", "text": "bar"},
],
},
{
"type": "function_call",
"call_id": "call_123",
"name": "get_weather",
"arguments": '{"location": "San Francisco"}',
},
{
"type": "function_call",
"call_id": "call_234",
"name": "get_weather_2",
"arguments": '{"location": "New York"}',
"id": "fc_123",
},
{"type": "text", "text": "Hello "},
{
"type": "text",
"text": "world",
"annotations": [
{"type": "url_citation", "url": "https://example.com"},
{
"type": "file_citation",
"filename": "my doc",
"index": 1,
"file_id": "file_123",
},
{"bar": "baz"},
],
},
{"type": "image_generation_call", "id": "img_123", "result": "..."},
{"type": "something_else", "foo": "bar"},
],
tool_calls=[
{
"type": "tool_call",
"id": "call_123",
"name": "get_weather",
"args": {"location": "San Francisco"},
},
{
# Make values different to check we pull from content when
# available
"type": "tool_call",
"id": "call_234",
"name": "get_weather_3",
"args": {"location": "Boston"},
},
],
id="resp123",
response_metadata={"foo": "bar"},
),
AIMessage(
[
{"type": "reasoning", "id": "abc123"},
{"type": "reasoning", "id": "abc234", "reasoning": "foo "},
{"type": "reasoning", "id": "abc234", "reasoning": "bar"},
{
"type": "tool_call",
"id": "call_123",
"name": "get_weather",
"arguments": '{"location": "San Francisco"}',
},
[
{"type": "reasoning", "id": "abc123", "summary": []},
{
"type": "reasoning",
"id": "abc234",
"summary": [
{"type": "summary_text", "text": "foo "},
{"type": "summary_text", "text": "bar"},
],
},
{
"type": "function_call",
"call_id": "call_123",
"name": "get_weather",
"arguments": '{"location": "San Francisco"}',
},
{
"type": "function_call",
"call_id": "call_234",
"name": "get_weather_2",
"arguments": '{"location": "New York"}',
"id": "fc_123",
},
{"type": "text", "text": "Hello "},
{
"type": "text",
"text": "world",
"annotations": [
{"type": "url_citation", "url": "https://example.com"},
{
"type": "file_citation",
"filename": "my doc",
"index": 1,
"file_id": "file_123",
},
{"bar": "baz"},
],
},
{"type": "image_generation_call", "id": "ig_123", "result": "..."},
{"type": "something_else", "foo": "bar"},
],
[
{
"type": "tool_call",
"id": "call_123",
"name": "get_weather",
"args": {"location": "San Francisco"},
},
{
"type": "tool_call",
"id": "call_234",
"name": "get_weather_2",
"args": {"location": "New York"},
},
],
[
{"type": "reasoning", "id": "abc123"},
{"type": "reasoning", "id": "abc234", "reasoning": "foo "},
{"type": "reasoning", "id": "abc234", "reasoning": "bar"},
{
"type": "tool_call",
"id": "call_123",
"name": "get_weather",
"args": {"location": "San Francisco"},
},
cast(
ToolCall,
{
"type": "tool_call",
"id": "call_234",
"name": "get_weather_2",
"arguments": '{"location": "New York"}',
"args": {"location": "New York"},
"item_id": "fc_123",
},
{"type": "text", "text": "Hello "},
{
"type": "text",
"text": "world",
"annotations": [
{"type": "url_citation", "url": "https://example.com"},
),
{"type": "text", "text": "Hello "},
{
"type": "text",
"text": "world",
"annotations": [
{"type": "citation", "url": "https://example.com"},
cast(
types.Citation,
{
"type": "document_citation",
"type": "citation",
"title": "my doc",
"index": 1,
"file_index": 1,
"file_id": "file_123",
},
{
"type": "non_standard_annotation",
"value": {"bar": "baz"},
},
],
},
{"type": "image", "base64": "...", "id": "img_123"},
{
"type": "non_standard",
"value": {"type": "something_else", "foo": "bar"},
},
],
tool_calls=[
{
"type": "tool_call",
"id": "call_123",
"name": "get_weather",
"args": {"location": "San Francisco"},
},
{
# Make values different to check we pull from content when
# available
"type": "tool_call",
"id": "call_234",
"name": "get_weather_3",
"args": {"location": "Boston"},
},
],
id="resp123",
response_metadata={"foo": "bar"},
),
),
{"type": "non_standard_annotation", "value": {"bar": "baz"}},
],
},
{"type": "image", "base64": "...", "id": "ig_123"},
{
"type": "non_standard",
"value": {"type": "something_else", "foo": "bar"},
},
],
)
],
)
def test_convert_to_v1_from_responses(
message_responses: AIMessage, expected: AIMessage
responses_content: list[dict[str, Any]],
tool_calls: list[ToolCall],
expected_content: list[types.ContentBlock],
) -> None:
result = _convert_to_v1_from_responses(message_responses)
assert result == expected
result = _convert_to_v1_from_responses(responses_content, tool_calls)
assert result == expected_content
def test_get_last_messages() -> None:

View File

@ -1,6 +1,6 @@
from langchain_openai.chat_models import __all__
EXPECTED_ALL = ["ChatOpenAI", "AzureChatOpenAI"]
EXPECTED_ALL = ["ChatOpenAI", "ChatOpenAIV1", "AzureChatOpenAI"]
def test_all_imports() -> None:

View File

@ -1,4 +1,4 @@
from typing import Any, Optional
from typing import Any, Literal, Optional
from unittest.mock import MagicMock, patch
import pytest
@ -698,7 +698,9 @@ def _strip_none(obj: Any) -> Any:
),
],
)
def test_responses_stream(output_version: str, expected_content: list[dict]) -> None:
def test_responses_stream(
output_version: Literal["v0", "responses/v1"], expected_content: list[dict]
) -> None:
llm = ChatOpenAI(
model="o4-mini", use_responses_api=True, output_version=output_version
)

View File

@ -6,6 +6,7 @@ from uuid import UUID
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.messages.v1 import MessageV1
from pydantic import BaseModel
@ -196,7 +197,7 @@ class FakeCallbackHandlerWithChatStart(FakeCallbackHandler):
def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
messages: Union[list[list[BaseMessage]], list[MessageV1]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
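The widened signature above lets the fake handler accept either the legacy nested batches of `BaseMessage` or a flat list of v1 messages. A rough sketch of handling both shapes, using string stand-ins and a hypothetical helper rather than the test handler:

from typing import Union

def count_messages(messages: Union[list[list[str]], list[str]]) -> int:
    # legacy shape: list of batches; v1 shape: flat list of messages
    if messages and isinstance(messages[0], list):
        return sum(len(batch) for batch in messages)
    return len(messages)

assert count_messages([["a", "b"], ["c"]]) == 3   # legacy nested batches
assert count_messages(["a", "b", "c"]) == 3       # flat v1-style list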

View File

@ -3,6 +3,7 @@ from langchain_openai import __all__
EXPECTED_ALL = [
"OpenAI",
"ChatOpenAI",
"ChatOpenAIV1",
"OpenAIEmbeddings",
"AzureOpenAI",
"AzureChatOpenAI",