openai[patch]: allow specification of output format for Responses API (#31686)

This commit is contained in:
ccurme 2025-06-26 13:41:43 -04:00 committed by GitHub
parent 59c2b81627
commit 88d5f3edcc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 328 additions and 35 deletions

View File

@ -113,6 +113,7 @@ def test_configurable() -> None:
"openai_api_base": None,
"openai_organization": None,
"openai_proxy": None,
"output_version": "v0",
"request_timeout": None,
"max_retries": None,
"presence_penalty": None,

View File

@ -128,6 +128,8 @@ def _convert_to_v03_ai_message(
else:
new_content.append(block)
message.content = new_content
if isinstance(message.id, str) and message.id.startswith("resp_"):
message.id = None
else:
pass
@ -137,13 +139,29 @@ def _convert_to_v03_ai_message(
def _convert_from_v03_ai_message(message: AIMessage) -> AIMessage:
"""Convert an old-style v0.3 AIMessage into the new content-block format."""
# Only update ChatOpenAI v0.3 AIMessages
if not (
# TODO: structure provenance into AIMessage
is_chatopenai_v03 = (
isinstance(message.content, list)
and all(isinstance(b, dict) for b in message.content)
) or not any(
item in message.additional_kwargs
for item in ["reasoning", "tool_outputs", "refusal", _FUNCTION_CALL_IDS_MAP_KEY]
):
) and (
any(
item in message.additional_kwargs
for item in [
"reasoning",
"tool_outputs",
"refusal",
_FUNCTION_CALL_IDS_MAP_KEY,
]
)
or (
isinstance(message.id, str)
and message.id.startswith("msg_")
and (response_id := message.response_metadata.get("id"))
and isinstance(response_id, str)
and response_id.startswith("resp_")
)
)
if not is_chatopenai_v03:
return message
content_order = [

View File

@ -649,6 +649,25 @@ class BaseChatOpenAI(BaseChatModel):
.. versionadded:: 0.3.9
"""
output_version: Literal["v0", "responses/v1"] = "v0"
"""Version of AIMessage output format to use.
This field is used to roll-out new output formats for chat model AIMessages
in a backwards-compatible way.
Supported values:
- ``"v0"``: AIMessage format as of langchain-openai 0.3.x.
- ``"responses/v1"``: Formats Responses API output
items into AIMessage content blocks.
Currently only impacts the Responses API. ``output_version="responses/v1"`` is
recommended.
.. versionadded:: 0.3.25
"""
model_config = ConfigDict(populate_by_name=True)
@model_validator(mode="before")
@ -903,6 +922,7 @@ class BaseChatOpenAI(BaseChatModel):
schema=original_schema_obj,
metadata=metadata,
has_reasoning=has_reasoning,
output_version=self.output_version,
)
if generation_chunk:
if run_manager:
@ -957,6 +977,7 @@ class BaseChatOpenAI(BaseChatModel):
schema=original_schema_obj,
metadata=metadata,
has_reasoning=has_reasoning,
output_version=self.output_version,
)
if generation_chunk:
if run_manager:
@ -1096,7 +1117,10 @@ class BaseChatOpenAI(BaseChatModel):
else:
response = self.root_client.responses.create(**payload)
return _construct_lc_result_from_responses_api(
response, schema=original_schema_obj, metadata=generation_info
response,
schema=original_schema_obj,
metadata=generation_info,
output_version=self.output_version,
)
elif self.include_response_headers:
raw_response = self.client.with_raw_response.create(**payload)
@ -1109,6 +1133,8 @@ class BaseChatOpenAI(BaseChatModel):
def _use_responses_api(self, payload: dict) -> bool:
if isinstance(self.use_responses_api, bool):
return self.use_responses_api
elif self.output_version == "responses/v1":
return True
elif self.include is not None:
return True
elif self.reasoning is not None:
@ -1327,7 +1353,10 @@ class BaseChatOpenAI(BaseChatModel):
else:
response = await self.root_async_client.responses.create(**payload)
return _construct_lc_result_from_responses_api(
response, schema=original_schema_obj, metadata=generation_info
response,
schema=original_schema_obj,
metadata=generation_info,
output_version=self.output_version,
)
elif self.include_response_headers:
raw_response = await self.async_client.with_raw_response.create(**payload)
@ -3540,6 +3569,7 @@ def _construct_lc_result_from_responses_api(
response: Response,
schema: Optional[type[_BM]] = None,
metadata: Optional[dict] = None,
output_version: Literal["v0", "responses/v1"] = "v0",
) -> ChatResult:
"""Construct ChatResponse from OpenAI Response API response."""
if response.error:
@ -3676,7 +3706,10 @@ def _construct_lc_result_from_responses_api(
tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls,
)
message = _convert_to_v03_ai_message(message)
if output_version == "v0":
message = _convert_to_v03_ai_message(message)
else:
pass
return ChatResult(generations=[ChatGeneration(message=message)])
@ -3688,6 +3721,7 @@ def _convert_responses_chunk_to_generation_chunk(
schema: Optional[type[_BM]] = None,
metadata: Optional[dict] = None,
has_reasoning: bool = False,
output_version: Literal["v0", "responses/v1"] = "v0",
) -> tuple[int, int, int, Optional[ChatGenerationChunk]]:
def _advance(output_idx: int, sub_idx: Optional[int] = None) -> None:
"""Advance indexes tracked during streaming.
@ -3756,12 +3790,15 @@ def _convert_responses_chunk_to_generation_chunk(
elif chunk.type == "response.output_text.done":
content.append({"id": chunk.item_id, "index": current_index})
elif chunk.type == "response.created":
response_metadata["id"] = chunk.response.id
id = chunk.response.id
response_metadata["id"] = chunk.response.id # Backwards compatibility
elif chunk.type == "response.completed":
msg = cast(
AIMessage,
(
_construct_lc_result_from_responses_api(chunk.response, schema=schema)
_construct_lc_result_from_responses_api(
chunk.response, schema=schema, output_version=output_version
)
.generations[0]
.message
),
@ -3773,7 +3810,10 @@ def _convert_responses_chunk_to_generation_chunk(
k: v for k, v in msg.response_metadata.items() if k != "id"
}
elif chunk.type == "response.output_item.added" and chunk.item.type == "message":
id = chunk.item.id
if output_version == "v0":
id = chunk.item.id
else:
pass
elif (
chunk.type == "response.output_item.added"
and chunk.item.type == "function_call"
@ -3868,9 +3908,13 @@ def _convert_responses_chunk_to_generation_chunk(
additional_kwargs=additional_kwargs,
id=id,
)
message = cast(
AIMessageChunk, _convert_to_v03_ai_message(message, has_reasoning=has_reasoning)
)
if output_version == "v0":
message = cast(
AIMessageChunk,
_convert_to_v03_ai_message(message, has_reasoning=has_reasoning),
)
else:
pass
return (
current_index,
current_output_index,

View File

@ -2,7 +2,7 @@
import json
import os
from typing import Annotated, Any, Optional, cast
from typing import Annotated, Any, Literal, Optional, cast
import openai
import pytest
@ -50,15 +50,11 @@ def _check_response(response: Optional[BaseMessage]) -> None:
assert response.usage_metadata["total_tokens"] > 0
assert response.response_metadata["model_name"]
assert response.response_metadata["service_tier"]
for tool_output in response.additional_kwargs["tool_outputs"]:
assert tool_output["id"]
assert tool_output["status"]
assert tool_output["type"]
@pytest.mark.vcr
def test_web_search() -> None:
llm = ChatOpenAI(model=MODEL_NAME)
llm = ChatOpenAI(model=MODEL_NAME, output_version="responses/v1")
first_response = llm.invoke(
"What was a positive news story from today?",
tools=[{"type": "web_search_preview"}],
@ -111,6 +107,11 @@ def test_web_search() -> None:
)
_check_response(response)
for msg in [first_response, full, response]:
assert isinstance(msg, AIMessage)
block_types = [block["type"] for block in msg.content] # type: ignore[index]
assert block_types == ["web_search_call", "text"]
@pytest.mark.flaky(retries=3, delay=1)
async def test_web_search_async() -> None:
@ -133,6 +134,12 @@ async def test_web_search_async() -> None:
assert isinstance(full, AIMessageChunk)
_check_response(full)
for msg in [response, full]:
assert msg.additional_kwargs["tool_outputs"]
assert len(msg.additional_kwargs["tool_outputs"]) == 1
tool_output = msg.additional_kwargs["tool_outputs"][0]
assert tool_output["type"] == "web_search_call"
@pytest.mark.flaky(retries=3, delay=1)
def test_function_calling() -> None:
@ -288,20 +295,32 @@ def test_function_calling_and_structured_output() -> None:
assert set(ai_msg.tool_calls[0]["args"]) == {"x", "y"}
def test_reasoning() -> None:
llm = ChatOpenAI(model="o3-mini", use_responses_api=True)
@pytest.mark.default_cassette("test_reasoning.yaml.gz")
@pytest.mark.vcr
@pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
def test_reasoning(output_version: Literal["v0", "responses/v1"]) -> None:
llm = ChatOpenAI(
model="o4-mini", use_responses_api=True, output_version=output_version
)
response = llm.invoke("Hello", reasoning={"effort": "low"})
assert isinstance(response, AIMessage)
assert response.additional_kwargs["reasoning"]
# Test init params + streaming
llm = ChatOpenAI(model="o3-mini", reasoning_effort="low", use_responses_api=True)
llm = ChatOpenAI(
model="o4-mini", reasoning={"effort": "low"}, output_version=output_version
)
full: Optional[BaseMessageChunk] = None
for chunk in llm.stream("Hello"):
assert isinstance(chunk, AIMessageChunk)
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessage)
assert full.additional_kwargs["reasoning"]
for msg in [response, full]:
if output_version == "v0":
assert msg.additional_kwargs["reasoning"]
else:
block_types = [block["type"] for block in msg.content]
assert block_types == ["reasoning", "text"]
def test_stateful_api() -> None:
@ -355,20 +374,37 @@ def test_file_search() -> None:
_check_response(full)
def test_stream_reasoning_summary() -> None:
@pytest.mark.default_cassette("test_stream_reasoning_summary.yaml.gz")
@pytest.mark.vcr
@pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
def test_stream_reasoning_summary(
output_version: Literal["v0", "responses/v1"],
) -> None:
llm = ChatOpenAI(
model="o4-mini",
# Routes to Responses API if `reasoning` is set.
reasoning={"effort": "medium", "summary": "auto"},
output_version=output_version,
)
message_1 = {"role": "user", "content": "What is 3^3?"}
message_1 = {
"role": "user",
"content": "What was the third tallest buliding in the year 2000?",
}
response_1: Optional[BaseMessageChunk] = None
for chunk in llm.stream([message_1]):
assert isinstance(chunk, AIMessageChunk)
response_1 = chunk if response_1 is None else response_1 + chunk
assert isinstance(response_1, AIMessageChunk)
reasoning = response_1.additional_kwargs["reasoning"]
assert set(reasoning.keys()) == {"id", "type", "summary"}
if output_version == "v0":
reasoning = response_1.additional_kwargs["reasoning"]
assert set(reasoning.keys()) == {"id", "type", "summary"}
else:
reasoning = next(
block
for block in response_1.content
if block["type"] == "reasoning" # type: ignore[index]
)
assert set(reasoning.keys()) == {"id", "type", "summary", "index"}
summary = reasoning["summary"]
assert isinstance(summary, list)
for block in summary:
@ -462,11 +498,11 @@ def test_mcp_builtin() -> None:
)
@pytest.mark.skip
@pytest.mark.vcr
def test_mcp_builtin_zdr() -> None:
llm = ChatOpenAI(
model="o4-mini",
use_responses_api=True,
output_version="responses/v1",
store=False,
include=["reasoning.encrypted_content"],
)

View File

@ -24,6 +24,7 @@
}),
'openai_api_type': 'azure',
'openai_api_version': '2021-10-01',
'output_version': 'v0',
'request_timeout': 60.0,
'stop': list([
]),

View File

@ -18,6 +18,7 @@
'lc': 1,
'type': 'secret',
}),
'output_version': 'v0',
'request_timeout': 60.0,
'stop': list([
]),

View File

@ -18,6 +18,7 @@
'lc': 1,
'type': 'secret',
}),
'output_version': 'v0',
'request_timeout': 60.0,
'stop': list([
]),

View File

@ -1192,6 +1192,7 @@ def test__construct_lc_result_from_responses_api_basic_text_response() -> None:
),
)
# v0
result = _construct_lc_result_from_responses_api(response)
assert isinstance(result, ChatResult)
@ -1209,6 +1210,16 @@ def test__construct_lc_result_from_responses_api_basic_text_response() -> None:
assert result.generations[0].message.response_metadata["id"] == "resp_123"
assert result.generations[0].message.response_metadata["model_name"] == "gpt-4o"
# responses/v1
result = _construct_lc_result_from_responses_api(
response, output_version="responses/v1"
)
assert result.generations[0].message.content == [
{"type": "text", "text": "Hello, world!", "annotations": [], "id": "msg_123"}
]
assert result.generations[0].message.id == "resp_123"
assert result.generations[0].message.response_metadata["id"] == "resp_123"
def test__construct_lc_result_from_responses_api_multiple_text_blocks() -> None:
"""Test a response with multiple text blocks."""
@ -1284,6 +1295,7 @@ def test__construct_lc_result_from_responses_api_multiple_messages() -> None:
],
)
# v0
result = _construct_lc_result_from_responses_api(response)
assert result.generations[0].message.content == [
@ -1297,6 +1309,23 @@ def test__construct_lc_result_from_responses_api_multiple_messages() -> None:
"id": "rs_123",
}
}
assert result.generations[0].message.id == "msg_234"
# responses/v1
result = _construct_lc_result_from_responses_api(
response, output_version="responses/v1"
)
assert result.generations[0].message.content == [
{"type": "text", "text": "foo", "annotations": [], "id": "msg_123"},
{
"type": "reasoning",
"summary": [{"type": "summary_text", "text": "reasoning foo"}],
"id": "rs_123",
},
{"type": "text", "text": "bar", "annotations": [], "id": "msg_234"},
]
assert result.generations[0].message.id == "resp_123"
def test__construct_lc_result_from_responses_api_refusal_response() -> None:
@ -1324,12 +1353,25 @@ def test__construct_lc_result_from_responses_api_refusal_response() -> None:
],
)
# v0
result = _construct_lc_result_from_responses_api(response)
assert result.generations[0].message.additional_kwargs["refusal"] == (
"I cannot assist with that request."
)
# responses/v1
result = _construct_lc_result_from_responses_api(
response, output_version="responses/v1"
)
assert result.generations[0].message.content == [
{
"type": "refusal",
"refusal": "I cannot assist with that request.",
"id": "msg_123",
}
]
def test__construct_lc_result_from_responses_api_function_call_valid_json() -> None:
"""Test a response with a valid function call."""
@ -1352,6 +1394,7 @@ def test__construct_lc_result_from_responses_api_function_call_valid_json() -> N
],
)
# v0
result = _construct_lc_result_from_responses_api(response)
msg: AIMessage = cast(AIMessage, result.generations[0].message)
@ -1368,6 +1411,22 @@ def test__construct_lc_result_from_responses_api_function_call_valid_json() -> N
== "func_123"
)
# responses/v1
result = _construct_lc_result_from_responses_api(
response, output_version="responses/v1"
)
msg = cast(AIMessage, result.generations[0].message)
assert msg.tool_calls
assert msg.content == [
{
"type": "function_call",
"id": "func_123",
"name": "get_weather",
"arguments": '{"location": "New York", "unit": "celsius"}',
"call_id": "call_123",
}
]
def test__construct_lc_result_from_responses_api_function_call_invalid_json() -> None:
"""Test a response with an invalid JSON function call."""
@ -1444,6 +1503,7 @@ def test__construct_lc_result_from_responses_api_complex_response() -> None:
user="user_123",
)
# v0
result = _construct_lc_result_from_responses_api(response)
# Check message content
@ -1472,6 +1532,28 @@ def test__construct_lc_result_from_responses_api_complex_response() -> None:
assert result.generations[0].message.response_metadata["status"] == "completed"
assert result.generations[0].message.response_metadata["user"] == "user_123"
# responses/v1
result = _construct_lc_result_from_responses_api(
response, output_version="responses/v1"
)
msg = cast(AIMessage, result.generations[0].message)
assert msg.response_metadata["metadata"] == {"key1": "value1", "key2": "value2"}
assert msg.content == [
{
"type": "text",
"text": "Here's the information you requested:",
"annotations": [],
"id": "msg_123",
},
{
"type": "function_call",
"id": "func_123",
"call_id": "call_123",
"name": "get_weather",
"arguments": '{"location": "New York"}',
},
]
def test__construct_lc_result_from_responses_api_no_usage_metadata() -> None:
"""Test a response without usage metadata."""
@ -1525,6 +1607,7 @@ def test__construct_lc_result_from_responses_api_web_search_response() -> None:
],
)
# v0
result = _construct_lc_result_from_responses_api(response)
assert "tool_outputs" in result.generations[0].message.additional_kwargs
@ -1542,6 +1625,14 @@ def test__construct_lc_result_from_responses_api_web_search_response() -> None:
== "completed"
)
# responses/v1
result = _construct_lc_result_from_responses_api(
response, output_version="responses/v1"
)
assert result.generations[0].message.content == [
{"type": "web_search_call", "id": "websearch_123", "status": "completed"}
]
def test__construct_lc_result_from_responses_api_file_search_response() -> None:
"""Test a response with file search output."""
@ -1572,6 +1663,7 @@ def test__construct_lc_result_from_responses_api_file_search_response() -> None:
],
)
# v0
result = _construct_lc_result_from_responses_api(response)
assert "tool_outputs" in result.generations[0].message.additional_kwargs
@ -1612,6 +1704,28 @@ def test__construct_lc_result_from_responses_api_file_search_response() -> None:
== 0.95
)
# responses/v1
result = _construct_lc_result_from_responses_api(
response, output_version="responses/v1"
)
assert result.generations[0].message.content == [
{
"type": "file_search_call",
"id": "filesearch_123",
"status": "completed",
"queries": ["python code", "langchain"],
"results": [
{
"file_id": "file_123",
"filename": "example.py",
"score": 0.95,
"text": "def hello_world() -> None:\n print('Hello, world!')",
"attributes": {"language": "python", "size": 42},
}
],
}
]
def test__construct_lc_result_from_responses_api_mixed_search_responses() -> None:
"""Test a response with both web search and file search outputs."""
@ -1656,6 +1770,7 @@ def test__construct_lc_result_from_responses_api_mixed_search_responses() -> Non
],
)
# v0
result = _construct_lc_result_from_responses_api(response)
# Check message content
@ -1686,6 +1801,34 @@ def test__construct_lc_result_from_responses_api_mixed_search_responses() -> Non
assert file_search["queries"] == ["python code"]
assert file_search["results"][0]["filename"] == "example.py"
# responses/v1
result = _construct_lc_result_from_responses_api(
response, output_version="responses/v1"
)
assert result.generations[0].message.content == [
{
"type": "text",
"text": "Here's what I found:",
"annotations": [],
"id": "msg_123",
},
{"type": "web_search_call", "id": "websearch_123", "status": "completed"},
{
"type": "file_search_call",
"id": "filesearch_123",
"queries": ["python code"],
"results": [
{
"file_id": "file_123",
"filename": "example.py",
"score": 0.95,
"text": "def hello_world() -> None:\n print('Hello, world!')",
}
],
"status": "completed",
},
]
def test__construct_responses_api_input_human_message_with_text_blocks_conversion() -> (
None
@ -1706,7 +1849,29 @@ def test__construct_responses_api_input_human_message_with_text_blocks_conversio
def test__construct_responses_api_input_multiple_message_components() -> None:
"""Test that human messages with text blocks are properly converted."""
messages: list = [
# v0
messages = [
AIMessage(
content=[{"type": "text", "text": "foo"}, {"type": "text", "text": "bar"}],
id="msg_123",
response_metadata={"id": "resp_123"},
)
]
result = _construct_responses_api_input(messages)
assert result == [
{
"type": "message",
"role": "assistant",
"content": [
{"type": "output_text", "text": "foo", "annotations": []},
{"type": "output_text", "text": "bar", "annotations": []},
],
"id": "msg_123",
}
]
# responses/v1
messages = [
AIMessage(
content=[
{"type": "text", "text": "foo", "id": "msg_123"},

View File

@ -1,7 +1,6 @@
from typing import Any, Optional
from unittest.mock import MagicMock, patch
import pytest
from langchain_core.messages import AIMessageChunk, BaseMessageChunk
from openai.types.responses import (
ResponseCompletedEvent,
@ -601,9 +600,18 @@ responses_stream = [
]
@pytest.mark.xfail(reason="Will be fixed with output format flags.")
def _strip_none(obj: Any) -> Any:
"""Recursively strip None values from dictionaries and lists."""
if isinstance(obj, dict):
return {k: _strip_none(v) for k, v in obj.items() if v is not None}
elif isinstance(obj, list):
return [_strip_none(v) for v in obj]
else:
return obj
def test_responses_stream() -> None:
llm = ChatOpenAI(model="o4-mini", use_responses_api=True)
llm = ChatOpenAI(model="o4-mini", output_version="responses/v1")
mock_client = MagicMock()
def mock_create(*args: Any, **kwargs: Any) -> MockSyncContextManager:
@ -644,3 +652,20 @@ def test_responses_stream() -> None:
]
assert full.content == expected_content
assert full.additional_kwargs == {}
assert full.id == "resp_123"
# Test reconstruction
payload = llm._get_request_payload([full])
completed = [
item
for item in responses_stream
if item.type == "response.completed" # type: ignore[attr-defined]
]
assert len(completed) == 1
response = completed[0].response # type: ignore[attr-defined]
assert len(response.output) == len(payload["input"])
for idx, item in enumerate(response.output):
dumped = _strip_none(item.model_dump())
_ = dumped.pop("status", None)
assert dumped == payload["input"][idx]

View File

@ -10,6 +10,7 @@
'max_retries': 2,
'max_tokens': 100,
'model_name': 'grok-beta',
'output_version': 'v0',
'request_timeout': 60.0,
'stop': list([
]),