openai[patch]: support structured output via Responses API (#30265)

Also runs all standard tests using the Responses API.
ccurme 2025-03-14 15:14:23 -04:00 committed by GitHub
parent f54f14b747
commit c74e7b997d
7 changed files with 308 additions and 50 deletions
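
A minimal usage sketch (based on the integration tests added below; the model
name and prompt are illustrative, not prescribed by the commit):

    from pydantic import BaseModel

    from langchain_openai import ChatOpenAI

    class Foo(BaseModel):
        response: str

    llm = ChatOpenAI(model="gpt-4o-mini", use_responses_api=True)
    msg = llm.invoke("how are ya", response_format=Foo)
    # For Pydantic schemas, a parsed instance is surfaced alongside the text:
    parsed = msg.additional_kwargs["parsed"]
    assert isinstance(parsed, Foo)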


@@ -751,11 +751,12 @@ class BaseChatOpenAI(BaseChatModel):
kwargs["stream"] = True
payload = self._get_request_payload(messages, stop=stop, **kwargs)
context_manager = self.root_client.responses.create(**payload)
original_schema_obj = kwargs.get("response_format")
with context_manager as response:
for chunk in response:
if generation_chunk := _convert_responses_chunk_to_generation_chunk(
chunk
chunk, schema=original_schema_obj
):
if run_manager:
run_manager.on_llm_new_token(
@@ -773,11 +774,12 @@ class BaseChatOpenAI(BaseChatModel):
kwargs["stream"] = True
payload = self._get_request_payload(messages, stop=stop, **kwargs)
context_manager = await self.root_async_client.responses.create(**payload)
original_schema_obj = kwargs.get("response_format")
async with context_manager as response:
async for chunk in response:
if generation_chunk := _convert_responses_chunk_to_generation_chunk(
chunk
chunk, schema=original_schema_obj
):
if run_manager:
await run_manager.on_llm_new_token(
@@ -880,8 +882,14 @@ class BaseChatOpenAI(BaseChatModel):
response = raw_response.parse()
generation_info = {"headers": dict(raw_response.headers)}
elif self._use_responses_api(payload):
response = self.root_client.responses.create(**payload)
return _construct_lc_result_from_responses_api(response)
original_schema_obj = kwargs.get("response_format")
if original_schema_obj and _is_pydantic_class(original_schema_obj):
response = self.root_client.responses.parse(**payload)
else:
response = self.root_client.responses.create(**payload)
return _construct_lc_result_from_responses_api(
response, schema=original_schema_obj
)
else:
response = self.client.create(**payload)
return self._create_chat_result(response, generation_info)
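
# Note (editor's sketch, not part of the diff): responses.parse() validates the
# model output against the Pydantic class and exposes the result as
# content.parsed, which _construct_lc_result_from_responses_api picks up below;
# responses.create() returns raw text that is parsed manually instead.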
@@ -1062,8 +1070,15 @@ class BaseChatOpenAI(BaseChatModel):
response = raw_response.parse()
generation_info = {"headers": dict(raw_response.headers)}
elif self._use_responses_api(payload):
response = await self.root_async_client.responses.create(**payload)
return _construct_lc_result_from_responses_api(response)
original_schema_obj = kwargs.get("response_format")
if original_schema_obj and _is_pydantic_class(original_schema_obj):
response = await self.root_async_client.responses.parse(**payload)
else:
response = await self.root_async_client.responses.create(**payload)
return _construct_lc_result_from_responses_api(
response, schema=original_schema_obj
)
else:
response = await self.async_client.create(**payload)
return await run_in_executor(
@@ -2833,23 +2848,45 @@ def _construct_responses_api_payload(
if tool_choice := payload.pop("tool_choice", None):
# chat api: {"type": "function", "function": {"name": "..."}}
# responses api: {"type": "function", "name": "..."}
if tool_choice["type"] == "function" and "function" in tool_choice:
if (
isinstance(tool_choice, dict)
and tool_choice["type"] == "function"
and "function" in tool_choice
):
payload["tool_choice"] = {"type": "function", **tool_choice["function"]}
else:
payload["tool_choice"] = tool_choice
if response_format := payload.pop("response_format", None):
# Structured output
if schema := payload.pop("response_format", None):
if payload.get("text"):
text = payload["text"]
raise ValueError(
"Can specify at most one of 'response_format' or 'text', received both:"
f"\n{response_format=}\n{text=}"
f"\n{schema=}\n{text=}"
)
# chat api: {"type": "json_schema, "json_schema": {"schema": {...}, "name": "...", "description": "...", "strict": ...}} # noqa: E501
# responses api: {"type": "json_schema, "schema": {...}, "name": "...", "description": "...", "strict": ...} # noqa: E501
if response_format["type"] == "json_schema":
payload["text"] = {"type": "json_schema", **response_format["json_schema"]}
# For pydantic + non-streaming case, we use responses.parse.
# Otherwise, we use responses.create.
if not payload.get("stream") and _is_pydantic_class(schema):
payload["text_format"] = schema
else:
payload["text"] = response_format
if _is_pydantic_class(schema):
schema_dict = schema.model_json_schema()
else:
schema_dict = schema
if schema_dict == {"type": "json_object"}: # JSON mode
payload["text"] = {"format": {"type": "json_object"}}
elif (
(response_format := _convert_to_openai_response_format(schema_dict))
and (isinstance(response_format, dict))
and (response_format["type"] == "json_schema")
):
payload["text"] = {
"format": {"type": "json_schema", **response_format["json_schema"]}
}
else:
pass
return payload
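
# Illustrative sketch (not part of the diff): the format translation the hunk
# above performs, written as a standalone helper. The name _to_responses_text
# is hypothetical; the dict shapes follow the code above.
def _to_responses_text(response_format: dict) -> dict:
    """Translate a Chat Completions response_format into a Responses 'text' value."""
    if response_format == {"type": "json_object"}:  # JSON mode passes through
        return {"format": {"type": "json_object"}}
    if response_format.get("type") == "json_schema":
        # chat api:      {"type": "json_schema", "json_schema": {"schema": ..., "name": ..., "strict": ...}}
        # responses api: {"format": {"type": "json_schema", "schema": ..., "name": ..., "strict": ...}}
        return {"format": {"type": "json_schema", **response_format["json_schema"]}}
    raise ValueError(f"Unsupported response_format: {response_format}")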
@@ -2857,6 +2894,9 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
input_ = []
for lc_msg in messages:
msg = _convert_message_to_dict(lc_msg)
# "name" parameter unsupported
if "name" in msg:
msg.pop("name")
if msg["role"] == "tool":
tool_output = msg["content"]
if not isinstance(tool_output, str):
@@ -2872,17 +2912,20 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
if tool_calls := msg.pop("tool_calls", None):
# TODO: should you be able to preserve the function call object id on
# the langchain tool calls themselves?
if not lc_msg.additional_kwargs.get(_FUNCTION_CALL_IDS_MAP_KEY):
raise ValueError("")
function_call_ids = lc_msg.additional_kwargs[_FUNCTION_CALL_IDS_MAP_KEY]
function_call_ids = lc_msg.additional_kwargs.get(
_FUNCTION_CALL_IDS_MAP_KEY
)
for tool_call in tool_calls:
function_call = {
"type": "function_call",
"name": tool_call["function"]["name"],
"arguments": tool_call["function"]["arguments"],
"call_id": tool_call["id"],
"id": function_call_ids[tool_call["id"]],
}
if function_call_ids is not None and (
_id := function_call_ids.get(tool_call["id"])
):
function_call["id"] = _id
function_calls.append(function_call)
msg["content"] = msg.get("content") or []
@@ -2949,7 +2992,9 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
return input_
def _construct_lc_result_from_responses_api(response: Response) -> ChatResult:
def _construct_lc_result_from_responses_api(
response: Response, schema: Optional[Type[_BM]] = None
) -> ChatResult:
"""Construct ChatResponse from OpenAI Response API response."""
if response.error:
raise ValueError(response.error)
@@ -2994,6 +3039,8 @@ def _construct_lc_result_from_responses_api(response: Response) -> ChatResult:
],
}
content_blocks.append(block)
if hasattr(content, "parsed"):
additional_kwargs["parsed"] = content.parsed
if content.type == "refusal":
additional_kwargs["refusal"] = content.refusal
msg_id = output.id
@@ -3034,6 +3081,35 @@ def _construct_lc_result_from_responses_api(response: Response) -> ChatResult:
additional_kwargs["tool_outputs"].append(tool_output)
else:
additional_kwargs["tool_outputs"] = [tool_output]
# Workaround for parsing structured output in the streaming case.
# from openai import OpenAI
# from pydantic import BaseModel
# class Foo(BaseModel):
# response: str
# client = OpenAI()
# client.responses.parse(
# model="gpt-4o-mini",
# input=[{"content": "how are ya", "role": "user"}],
# text_format=Foo,
# stream=True, # <-- errors
# )
if (
schema is not None
and "parsed" not in additional_kwargs
and response.text
and (text_config := response.text.model_dump())
and (format_ := text_config.get("format", {}))
and (format_.get("type") == "json_schema")
):
parsed_dict = json.loads(response.output_text)
if schema and _is_pydantic_class(schema):
parsed = schema(**parsed_dict)
else:
parsed = parsed_dict
additional_kwargs["parsed"] = parsed
message = AIMessage(
content=content_blocks,
id=msg_id,
@@ -3047,7 +3123,7 @@ def _construct_lc_result_from_responses_api(response: Response) -> ChatResult:
def _convert_responses_chunk_to_generation_chunk(
chunk: Any,
chunk: Any, schema: Optional[Type[_BM]] = None
) -> Optional[ChatGenerationChunk]:
content = []
tool_call_chunks: list = []
@@ -3074,11 +3150,13 @@ def _convert_responses_chunk_to_generation_chunk(
msg = cast(
AIMessage,
(
_construct_lc_result_from_responses_api(chunk.response)
_construct_lc_result_from_responses_api(chunk.response, schema=schema)
.generations[0]
.message
),
)
if parsed := msg.additional_kwargs.get("parsed"):
additional_kwargs["parsed"] = parsed
usage_metadata = msg.usage_metadata
response_metadata = {
k: v for k, v in msg.response_metadata.items() if k != "id"

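Consuming parsed structured output from a stream (a sketch mirroring the new
integration tests, reusing llm and Foo from the sketch above; the aggregated
final chunk carries the parsed object):

    full = None
    for chunk in llm.stream("how are ya", response_format=Foo):
        full = chunk if full is None else full + chunk
    # Populated by the streaming workaround above from json.loads(response.output_text):
    parsed = full.additional_kwargs["parsed"]
    assert isinstance(parsed, Foo)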

@@ -1,5 +1,6 @@
"""Test Responses API usage."""
import json
import os
from typing import Any, Optional, cast
@@ -10,9 +11,13 @@ from langchain_core.messages import (
BaseMessage,
BaseMessageChunk,
)
from pydantic import BaseModel
from typing_extensions import TypedDict
from langchain_openai import ChatOpenAI
MODEL_NAME = "gpt-4o-mini"
def _check_response(response: Optional[BaseMessage]) -> None:
assert isinstance(response, AIMessage)
@@ -48,7 +53,7 @@ def _check_response(response: Optional[BaseMessage]) -> None:
def test_web_search() -> None:
llm = ChatOpenAI(model="gpt-4o-mini")
llm = ChatOpenAI(model=MODEL_NAME)
first_response = llm.invoke(
"What was a positive news story from today?",
tools=[{"type": "web_search_preview"}],
@@ -94,7 +99,7 @@ def test_web_search() -> None:
async def test_web_search_async() -> None:
llm = ChatOpenAI(model="gpt-4o-mini")
llm = ChatOpenAI(model=MODEL_NAME)
response = await llm.ainvoke(
"What was a positive news story from today?",
tools=[{"type": "web_search_preview"}],
@@ -119,7 +124,7 @@ def test_function_calling() -> None:
"""return x * y"""
return x * y
llm = ChatOpenAI(model="gpt-4o-mini")
llm = ChatOpenAI(model=MODEL_NAME)
bound_llm = llm.bind_tools([multiply, {"type": "web_search_preview"}])
ai_msg = cast(AIMessage, bound_llm.invoke("whats 5 * 4"))
assert len(ai_msg.tool_calls) == 1
@@ -138,8 +143,110 @@ def test_function_calling() -> None:
_check_response(response)
class Foo(BaseModel):
response: str
class FooDict(TypedDict):
response: str
def test_parsed_pydantic_schema() -> None:
llm = ChatOpenAI(model=MODEL_NAME, use_responses_api=True)
response = llm.invoke("how are ya", response_format=Foo)
parsed = Foo(**json.loads(response.text()))
assert parsed == response.additional_kwargs["parsed"]
assert parsed.response
# Test stream
full: Optional[BaseMessageChunk] = None
for chunk in llm.stream("how are ya", response_format=Foo):
assert isinstance(chunk, AIMessageChunk)
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk)
parsed = Foo(**json.loads(full.text()))
assert parsed == full.additional_kwargs["parsed"]
assert parsed.response
async def test_parsed_pydantic_schema_async() -> None:
llm = ChatOpenAI(model=MODEL_NAME, use_responses_api=True)
response = await llm.ainvoke("how are ya", response_format=Foo)
parsed = Foo(**json.loads(response.text()))
assert parsed == response.additional_kwargs["parsed"]
assert parsed.response
# Test stream
full: Optional[BaseMessageChunk] = None
async for chunk in llm.astream("how are ya", response_format=Foo):
assert isinstance(chunk, AIMessageChunk)
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk)
parsed = Foo(**json.loads(full.text()))
assert parsed == full.additional_kwargs["parsed"]
assert parsed.response
@pytest.mark.parametrize("schema", [Foo.model_json_schema(), FooDict])
def test_parsed_dict_schema(schema: Any) -> None:
llm = ChatOpenAI(model=MODEL_NAME, use_responses_api=True)
response = llm.invoke("how are ya", response_format=schema)
parsed = json.loads(response.text())
assert parsed == response.additional_kwargs["parsed"]
assert parsed["response"] and isinstance(parsed["response"], str)
# Test stream
full: Optional[BaseMessageChunk] = None
for chunk in llm.stream("how are ya", response_format=schema):
assert isinstance(chunk, AIMessageChunk)
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk)
parsed = json.loads(full.text())
assert parsed == full.additional_kwargs["parsed"]
assert parsed["response"] and isinstance(parsed["response"], str)
@pytest.mark.parametrize("schema", [Foo.model_json_schema(), FooDict])
async def test_parsed_dict_schema_async(schema: Any) -> None:
llm = ChatOpenAI(model=MODEL_NAME, use_responses_api=True)
response = await llm.ainvoke("how are ya", response_format=schema)
parsed = json.loads(response.text())
assert parsed == response.additional_kwargs["parsed"]
assert parsed["response"] and isinstance(parsed["response"], str)
# Test stream
full: Optional[BaseMessageChunk] = None
async for chunk in llm.astream("how are ya", response_format=schema):
assert isinstance(chunk, AIMessageChunk)
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk)
parsed = json.loads(full.text())
assert parsed == full.additional_kwargs["parsed"]
assert parsed["response"] and isinstance(parsed["response"], str)
def test_function_calling_and_structured_output() -> None:
def multiply(x: int, y: int) -> int:
"""return x * y"""
return x * y
llm = ChatOpenAI(model=MODEL_NAME)
bound_llm = llm.bind_tools([multiply], response_format=Foo, strict=True)
# Test structured output
response = llm.invoke("how are ya", response_format=Foo)
parsed = Foo(**json.loads(response.text()))
assert parsed == response.additional_kwargs["parsed"]
assert parsed.response
# Test function calling
ai_msg = cast(AIMessage, bound_llm.invoke("whats 5 * 4"))
assert len(ai_msg.tool_calls) == 1
assert ai_msg.tool_calls[0]["name"] == "multiply"
assert set(ai_msg.tool_calls[0]["args"]) == {"x", "y"}
def test_stateful_api() -> None:
llm = ChatOpenAI(model="gpt-4o-mini", use_responses_api=True)
llm = ChatOpenAI(model=MODEL_NAME, use_responses_api=True)
response = llm.invoke("how are you, my name is Bobo")
assert "id" in response.response_metadata
@@ -152,7 +259,7 @@ def test_stateful_api() -> None:
def test_file_search() -> None:
pytest.skip() # TODO: set up infra
llm = ChatOpenAI(model="gpt-4o-mini")
llm = ChatOpenAI(model=MODEL_NAME)
tool = {
"type": "file_search",
"vector_store_ids": [os.environ["OPENAI_VECTOR_STORE_ID"]],


@@ -0,0 +1,23 @@
"""Standard LangChain interface tests for Responses API"""
from typing import Type
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_openai import ChatOpenAI
from tests.integration_tests.chat_models.test_base_standard import TestOpenAIStandard
class TestOpenAIResponses(TestOpenAIStandard):
@property
def chat_model_class(self) -> Type[BaseChatModel]:
return ChatOpenAI
@property
def chat_model_params(self) -> dict:
return {"model": "gpt-4o-mini", "use_responses_api": True}
@pytest.mark.xfail(reason="Unsupported.")
def test_stop_sequence(self, model: BaseChatModel) -> None:
super().test_stop_sequence(model)


@@ -0,0 +1,31 @@
# serializer version: 1
# name: TestOpenAIResponses.test_serdes[serialized]
dict({
'id': list([
'langchain',
'chat_models',
'openai',
'ChatOpenAI',
]),
'kwargs': dict({
'max_retries': 2,
'max_tokens': 100,
'model_name': 'gpt-3.5-turbo',
'openai_api_key': dict({
'id': list([
'OPENAI_API_KEY',
]),
'lc': 1,
'type': 'secret',
}),
'request_timeout': 60.0,
'stop': list([
]),
'temperature': 0.0,
'use_responses_api': True,
}),
'lc': 1,
'name': 'ChatOpenAI',
'type': 'constructor',
})
# ---


@@ -1569,23 +1569,6 @@ def test__construct_responses_api_input_ai_message_with_tool_calls_and_content()
assert result[1]["id"] == "func_456"
def test__construct_responses_api_input_missing_function_call_ids() -> None:
"""Test AI messages with tool calls but missing function call IDs raise an error."""
tool_calls = [
{
"id": "call_123",
"name": "get_weather",
"args": {"location": "San Francisco"},
"type": "tool_call",
}
]
ai_message = AIMessage(content="", tool_calls=tool_calls)
with pytest.raises(ValueError):
_construct_responses_api_input([ai_message])
def test__construct_responses_api_input_tool_message_conversion() -> None:
"""Test that tool messages are properly converted to function_call_output."""
messages = [


@@ -0,0 +1,36 @@
"""Standard LangChain interface tests"""
from typing import Tuple, Type
from langchain_core.language_models import BaseChatModel
from langchain_tests.unit_tests import ChatModelUnitTests
from langchain_openai import ChatOpenAI
class TestOpenAIResponses(ChatModelUnitTests):
@property
def chat_model_class(self) -> Type[BaseChatModel]:
return ChatOpenAI
@property
def chat_model_params(self) -> dict:
return {"use_responses_api": True}
@property
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
return (
{
"OPENAI_API_KEY": "api_key",
"OPENAI_ORG_ID": "org_id",
"OPENAI_API_BASE": "api_base",
"OPENAI_PROXY": "https://proxy.com",
},
{},
{
"openai_api_key": "api_key",
"openai_organization": "org_id",
"openai_api_base": "api_base",
"openai_proxy": "https://proxy.com",
},
)


@@ -416,7 +416,7 @@ class ChatModelIntegrationTests(ChatModelTests):
result = model.invoke("Hello")
assert result is not None
assert isinstance(result, AIMessage)
assert isinstance(result.content, str)
assert isinstance(result.text(), str)
assert len(result.content) > 0
async def test_ainvoke(self, model: BaseChatModel) -> None:
@@ -448,7 +448,7 @@ class ChatModelIntegrationTests(ChatModelTests):
result = await model.ainvoke("Hello")
assert result is not None
assert isinstance(result, AIMessage)
assert isinstance(result.content, str)
assert isinstance(result.text(), str)
assert len(result.content) > 0
def test_stream(self, model: BaseChatModel) -> None:
@@ -542,7 +542,7 @@ class ChatModelIntegrationTests(ChatModelTests):
for result in batch_results:
assert result is not None
assert isinstance(result, AIMessage)
assert isinstance(result.content, str)
assert isinstance(result.text(), str)
assert len(result.content) > 0
async def test_abatch(self, model: BaseChatModel) -> None:
@@ -571,7 +571,7 @@ class ChatModelIntegrationTests(ChatModelTests):
for result in batch_results:
assert result is not None
assert isinstance(result, AIMessage)
assert isinstance(result.content, str)
assert isinstance(result.text(), str)
assert len(result.content) > 0
def test_conversation(self, model: BaseChatModel) -> None:
@@ -600,7 +600,7 @@ class ChatModelIntegrationTests(ChatModelTests):
result = model.invoke(messages)
assert result is not None
assert isinstance(result, AIMessage)
assert isinstance(result.content, str)
assert isinstance(result.text(), str)
assert len(result.content) > 0
def test_double_messages_conversation(self, model: BaseChatModel) -> None:
@@ -638,7 +638,7 @@ class ChatModelIntegrationTests(ChatModelTests):
result = model.invoke(messages)
assert result is not None
assert isinstance(result, AIMessage)
assert isinstance(result.content, str)
assert isinstance(result.text(), str)
assert len(result.content) > 0
def test_usage_metadata(self, model: BaseChatModel) -> None:
@@ -2136,7 +2136,7 @@ class ChatModelIntegrationTests(ChatModelTests):
result = model.invoke([HumanMessage("hello", name="example_user")])
assert result is not None
assert isinstance(result, AIMessage)
assert isinstance(result.content, str)
assert isinstance(result.text(), str)
assert len(result.content) > 0
def test_agent_loop(self, model: BaseChatModel) -> None:
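
The switch from result.content to result.text() reflects that with
use_responses_api=True the message content is a list of typed blocks rather
than a plain string; a minimal sketch (block shape assumed from
_construct_lc_result_from_responses_api above):

    from langchain_core.messages import AIMessage

    msg = AIMessage(content=[{"type": "text", "text": "Hello!", "annotations": []}])
    assert msg.text() == "Hello!"         # text() concatenates the text blocks
    assert isinstance(msg.content, list)  # no longer a plain str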