openai[patch]: support structured output via Responses API (#30265)

Also runs all standard tests using the Responses API.
ccurme
2025-03-14 15:14:23 -04:00
committed by GitHub
parent f54f14b747
commit c74e7b997d
7 changed files with 308 additions and 50 deletions


@@ -751,11 +751,12 @@ class BaseChatOpenAI(BaseChatModel):
kwargs["stream"] = True
payload = self._get_request_payload(messages, stop=stop, **kwargs)
context_manager = self.root_client.responses.create(**payload)
original_schema_obj = kwargs.get("response_format")
with context_manager as response:
for chunk in response:
if generation_chunk := _convert_responses_chunk_to_generation_chunk(
chunk
chunk, schema=original_schema_obj
):
if run_manager:
run_manager.on_llm_new_token(
@@ -773,11 +774,12 @@ class BaseChatOpenAI(BaseChatModel):
kwargs["stream"] = True
payload = self._get_request_payload(messages, stop=stop, **kwargs)
context_manager = await self.root_async_client.responses.create(**payload)
original_schema_obj = kwargs.get("response_format")
async with context_manager as response:
async for chunk in response:
if generation_chunk := _convert_responses_chunk_to_generation_chunk(
chunk
chunk, schema=original_schema_obj
):
if run_manager:
await run_manager.on_llm_new_token(
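Both streaming paths now capture the original response_format kwarg and thread it into every chunk conversion. A minimal consumer-side sketch of what this enables, assuming a langchain-openai build containing this patch; the use_responses_api flag routing requests to /v1/responses, the model name, and the Movie schema are illustrative assumptions, not part of this diff:

from langchain_openai import ChatOpenAI
from pydantic import BaseModel

class Movie(BaseModel):
    title: str
    year: int

# bind() passes response_format through kwargs into _stream/_astream above,
# where it is picked up as original_schema_obj.
llm = ChatOpenAI(model="gpt-4o-mini", use_responses_api=True)
structured_llm = llm.bind(response_format=Movie)

for chunk in structured_llm.stream("Name a classic sci-fi movie."):
    # Intermediate chunks carry text deltas; the chunk built from the final
    # response.completed event also carries additional_kwargs["parsed"].
    print(chunk.additional_kwargs.get("parsed"))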
@@ -880,8 +882,14 @@ class BaseChatOpenAI(BaseChatModel):
             response = raw_response.parse()
             generation_info = {"headers": dict(raw_response.headers)}
         elif self._use_responses_api(payload):
-            response = self.root_client.responses.create(**payload)
-            return _construct_lc_result_from_responses_api(response)
+            original_schema_obj = kwargs.get("response_format")
+            if original_schema_obj and _is_pydantic_class(original_schema_obj):
+                response = self.root_client.responses.parse(**payload)
+            else:
+                response = self.root_client.responses.create(**payload)
+            return _construct_lc_result_from_responses_api(
+                response, schema=original_schema_obj
+            )
         else:
             response = self.client.create(**payload)
         return self._create_chat_result(response, generation_info)
@@ -1062,8 +1070,15 @@ class BaseChatOpenAI(BaseChatModel):
             response = raw_response.parse()
             generation_info = {"headers": dict(raw_response.headers)}
         elif self._use_responses_api(payload):
-            response = await self.root_async_client.responses.create(**payload)
-            return _construct_lc_result_from_responses_api(response)
+            original_schema_obj = kwargs.get("response_format")
+            if original_schema_obj and _is_pydantic_class(original_schema_obj):
+                response = await self.root_async_client.responses.parse(**payload)
+            else:
+                response = await self.root_async_client.responses.create(**payload)
+            return _construct_lc_result_from_responses_api(
+                response, schema=original_schema_obj
+            )
         else:
             response = await self.async_client.create(**payload)
         return await run_in_executor(
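For non-streaming calls, the branch above prefers the SDK's responses.parse endpoint when the schema is a Pydantic class. At the raw OpenAI SDK level the distinction looks roughly like this; model name, prompt, and the Movie schema are placeholders:

from openai import OpenAI
from pydantic import BaseModel

class Movie(BaseModel):
    title: str
    year: int

client = OpenAI()

# responses.parse accepts a Pydantic model directly via text_format and
# attaches parsed objects to the output content parts.
response = client.responses.parse(
    model="gpt-4o-mini",
    input=[{"role": "user", "content": "Name a classic sci-fi movie."}],
    text_format=Movie,
)
print(response.output_parsed)  # Movie instance

# responses.create, by contrast, takes a plain JSON-schema "text" payload and
# returns unparsed JSON in response.output_text.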
@@ -2833,23 +2848,45 @@ def _construct_responses_api_payload(
if tool_choice := payload.pop("tool_choice", None):
# chat api: {"type": "function", "function": {"name": "..."}}
# responses api: {"type": "function", "name": "..."}
if tool_choice["type"] == "function" and "function" in tool_choice:
if (
isinstance(tool_choice, dict)
and tool_choice["type"] == "function"
and "function" in tool_choice
):
payload["tool_choice"] = {"type": "function", **tool_choice["function"]}
else:
payload["tool_choice"] = tool_choice
if response_format := payload.pop("response_format", None):
# Structured output
if schema := payload.pop("response_format", None):
if payload.get("text"):
text = payload["text"]
raise ValueError(
"Can specify at most one of 'response_format' or 'text', received both:"
f"\n{response_format=}\n{text=}"
f"\n{schema=}\n{text=}"
)
# chat api: {"type": "json_schema, "json_schema": {"schema": {...}, "name": "...", "description": "...", "strict": ...}} # noqa: E501
# responses api: {"type": "json_schema, "schema": {...}, "name": "...", "description": "...", "strict": ...} # noqa: E501
if response_format["type"] == "json_schema":
payload["text"] = {"type": "json_schema", **response_format["json_schema"]}
# For pydantic + non-streaming case, we use responses.parse.
# Otherwise, we use responses.create.
if not payload.get("stream") and _is_pydantic_class(schema):
payload["text_format"] = schema
else:
payload["text"] = response_format
if _is_pydantic_class(schema):
schema_dict = schema.model_json_schema()
else:
schema_dict = schema
if schema_dict == {"type": "json_object"}: # JSON mode
payload["text"] = {"format": {"type": "json_object"}}
elif (
(response_format := _convert_to_openai_response_format(schema_dict))
and (isinstance(response_format, dict))
and (response_format["type"] == "json_schema")
):
payload["text"] = {
"format": {"type": "json_schema", **response_format["json_schema"]}
}
else:
pass
return payload
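The net effect is a translation from the Chat Completions response_format shape to the Responses API's text parameter. A freestanding sketch of the dict-schema branch, using Pydantic's model_json_schema() as a stand-in for the library's _convert_to_openai_response_format helper (an approximation, not its exact output):

from pydantic import BaseModel

class Movie(BaseModel):
    title: str
    year: int

def to_responses_text_param(schema: dict) -> dict:
    """Map a chat-style response_format dict onto the Responses 'text' param."""
    if schema == {"type": "json_object"}:  # JSON mode
        return {"format": {"type": "json_object"}}
    if schema.get("type") == "json_schema":
        # chat api nests the details under "json_schema"; the responses api
        # flattens them next to the type tag inside "format".
        return {"format": {"type": "json_schema", **schema["json_schema"]}}
    raise ValueError(f"Unsupported response_format: {schema}")

chat_style = {
    "type": "json_schema",
    "json_schema": {
        "name": "Movie",
        "schema": Movie.model_json_schema(),
        "strict": True,
    },
}
print(to_responses_text_param(chat_style))

Non-streaming Pydantic schemas skip this mapping entirely: they are placed on the payload as text_format and handed to responses.parse unchanged.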
@@ -2857,6 +2894,9 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
     input_ = []
     for lc_msg in messages:
         msg = _convert_message_to_dict(lc_msg)
+        # "name" parameter unsupported
+        if "name" in msg:
+            msg.pop("name")
         if msg["role"] == "tool":
             tool_output = msg["content"]
             if not isinstance(tool_output, str):
@@ -2872,17 +2912,20 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
if tool_calls := msg.pop("tool_calls", None):
# TODO: should you be able to preserve the function call object id on
# the langchain tool calls themselves?
if not lc_msg.additional_kwargs.get(_FUNCTION_CALL_IDS_MAP_KEY):
raise ValueError("")
function_call_ids = lc_msg.additional_kwargs[_FUNCTION_CALL_IDS_MAP_KEY]
function_call_ids = lc_msg.additional_kwargs.get(
_FUNCTION_CALL_IDS_MAP_KEY
)
for tool_call in tool_calls:
function_call = {
"type": "function_call",
"name": tool_call["function"]["name"],
"arguments": tool_call["function"]["arguments"],
"call_id": tool_call["id"],
"id": function_call_ids[tool_call["id"]],
}
if function_call_ids is not None and (
_id := function_call_ids.get(tool_call["id"])
):
function_call["id"] = _id
function_calls.append(function_call)
msg["content"] = msg.get("content") or []
@@ -2949,7 +2992,9 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
     return input_


-def _construct_lc_result_from_responses_api(response: Response) -> ChatResult:
+def _construct_lc_result_from_responses_api(
+    response: Response, schema: Optional[Type[_BM]] = None
+) -> ChatResult:
     """Construct ChatResponse from OpenAI Response API response."""
     if response.error:
         raise ValueError(response.error)
@@ -2994,6 +3039,8 @@ def _construct_lc_result_from_responses_api(response: Response) -> ChatResult:
                         ],
                     }
                     content_blocks.append(block)
+                    if hasattr(content, "parsed"):
+                        additional_kwargs["parsed"] = content.parsed
                 if content.type == "refusal":
                     additional_kwargs["refusal"] = content.refusal
         msg_id = output.id
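When responses.parse was used, each output_text content part carries a .parsed attribute holding the instantiated Pydantic object; the two added lines copy it into additional_kwargs so it surfaces on the resulting AIMessage. Continuing the hypothetical structured_llm binding from the first sketch:

ai_message = structured_llm.invoke("Name a classic sci-fi movie.")
movie = ai_message.additional_kwargs["parsed"]  # Movie instance
print(movie.title, movie.year)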
@@ -3034,6 +3081,35 @@ def _construct_lc_result_from_responses_api(response: Response) -> ChatResult:
additional_kwargs["tool_outputs"].append(tool_output)
else:
additional_kwargs["tool_outputs"] = [tool_output]
# Workaround for parsing structured output in the streaming case.
# from openai import OpenAI
# from pydantic import BaseModel
# class Foo(BaseModel):
# response: str
# client = OpenAI()
# client.responses.parse(
# model="gpt-4o-mini",
# input=[{"content": "how are ya", "role": "user"}],
# text_format=Foo,
# stream=True, # <-- errors
# )
if (
schema is not None
and "parsed" not in additional_kwargs
and response.text
and (text_config := response.text.model_dump())
and (format_ := text_config.get("format", {}))
and (format_.get("type") == "json_schema")
):
parsed_dict = json.loads(response.output_text)
if schema and _is_pydantic_class(schema):
parsed = schema(**parsed_dict)
else:
parsed = parsed_dict
additional_kwargs["parsed"] = parsed
message = AIMessage(
content=content_blocks,
id=msg_id,
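As the commented repro notes, the SDK's responses.parse rejects stream=True, so the streaming path re-derives the parsed value itself once the final Response arrives. The fallback in isolation, as a rough standalone sketch (response is assumed to be an openai Response whose text format is json_schema, and schema a Pydantic class):

import json

def recover_parsed(response, schema):
    # Mirror of the guard above: only attempt recovery when a json_schema
    # text format was requested and a schema was supplied.
    text_config = response.text.model_dump() if response.text else {}
    format_ = text_config.get("format", {})
    if schema is None or format_.get("type") != "json_schema":
        return None
    # output_text is the SDK's concatenation of all output_text parts.
    parsed_dict = json.loads(response.output_text)
    return schema(**parsed_dict)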
@@ -3047,7 +3123,7 @@ def _construct_lc_result_from_responses_api(response: Response) -> ChatResult:
 def _convert_responses_chunk_to_generation_chunk(
-    chunk: Any,
+    chunk: Any, schema: Optional[Type[_BM]] = None
 ) -> Optional[ChatGenerationChunk]:
     content = []
     tool_call_chunks: list = []
@@ -3074,11 +3150,13 @@ def _convert_responses_chunk_to_generation_chunk(
         msg = cast(
             AIMessage,
             (
-                _construct_lc_result_from_responses_api(chunk.response)
+                _construct_lc_result_from_responses_api(chunk.response, schema=schema)
                 .generations[0]
                 .message
             ),
         )
+        if parsed := msg.additional_kwargs.get("parsed"):
+            additional_kwargs["parsed"] = parsed
         usage_metadata = msg.usage_metadata
         response_metadata = {
             k: v for k, v in msg.response_metadata.items() if k != "id"
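Because the converter forwards "parsed" into the chunk's additional_kwargs at the response.completed event, a consumer that aggregates the stream keeps it on the merged message. A closing sketch, again assuming the hypothetical structured_llm binding from the first example and relying on AIMessageChunk's "+" merging additional_kwargs:

from functools import reduce
from operator import add

chunks = list(structured_llm.stream("Name a classic sci-fi movie."))
final = reduce(add, chunks)  # only the last chunk contributes "parsed"
print(final.additional_kwargs.get("parsed"))  # Movie instance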