openai[patch]: support additional Responses API features (#30322)

- Include response headers in `response_metadata` (opt-in via `include_response_headers`)
- Support `max_tokens` (renamed to the Responses API's `max_output_tokens`)
- Support `reasoning_effort` (renamed to `reasoning={"effort": ...}`)
- Fix bug with structured output / `strict`
- Fix bug with simultaneous tool calling + structured output

A rough usage sketch of these features appears after the commit metadata below.
Author: ccurme
Date: 2025-03-17 12:02:21 -04:00
Committer: GitHub
Parent: d8510270ee
Commit: eb9b992aa6

6 changed files with 239 additions and 92 deletions
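
A usage sketch of the features above. Parameter names follow the diff below; the model name, prompt, and header key are illustrative assumptions, not part of the commit:

```python
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(
    model="o3-mini",                # assumed reasoning-capable model
    use_responses_api=True,
    include_response_headers=True,  # headers surface in response_metadata
    max_tokens=1000,                # renamed to max_output_tokens for Responses
    reasoning_effort="low",         # renamed to reasoning={"effort": "low"}
)
response = llm.invoke("Hello")
print(response.response_metadata["headers"]["x-request-id"])
```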

File: libs/partners/openai/langchain_openai/chat_models/base.py

@@ -750,18 +750,29 @@ class BaseChatOpenAI(BaseChatModel):
     ) -> Iterator[ChatGenerationChunk]:
         kwargs["stream"] = True
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
-        context_manager = self.root_client.responses.create(**payload)
+        if self.include_response_headers:
+            raw_context_manager = self.root_client.with_raw_response.responses.create(
+                **payload
+            )
+            context_manager = raw_context_manager.parse()
+            headers = {"headers": dict(raw_context_manager.headers)}
+        else:
+            context_manager = self.root_client.responses.create(**payload)
+            headers = {}
         original_schema_obj = kwargs.get("response_format")

         with context_manager as response:
+            is_first_chunk = True
             for chunk in response:
+                metadata = headers if is_first_chunk else {}
                 if generation_chunk := _convert_responses_chunk_to_generation_chunk(
-                    chunk, schema=original_schema_obj
+                    chunk, schema=original_schema_obj, metadata=metadata
                 ):
                     if run_manager:
                         run_manager.on_llm_new_token(
                             generation_chunk.text, chunk=generation_chunk
                         )
+                    is_first_chunk = False
                     yield generation_chunk

     async def _astream_responses(
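
In the streaming path above, headers are attached only to the first chunk, so aggregating chunks does not duplicate them. A minimal consumer-side sketch (the constructor arguments and prompt are illustrative assumptions):

```python
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(
    model="gpt-4o-mini", use_responses_api=True, include_response_headers=True
)
chunks = list(llm.stream("Write a haiku about streams"))

# Only the first chunk carries the response headers.
assert "headers" in chunks[0].response_metadata
assert all("headers" not in c.response_metadata for c in chunks[1:])
```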
@@ -773,18 +784,31 @@ class BaseChatOpenAI(BaseChatModel):
     ) -> AsyncIterator[ChatGenerationChunk]:
         kwargs["stream"] = True
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
-        context_manager = await self.root_async_client.responses.create(**payload)
+        if self.include_response_headers:
+            raw_context_manager = (
+                await self.root_async_client.with_raw_response.responses.create(
+                    **payload
+                )
+            )
+            context_manager = raw_context_manager.parse()
+            headers = {"headers": dict(raw_context_manager.headers)}
+        else:
+            context_manager = await self.root_async_client.responses.create(**payload)
+            headers = {}
         original_schema_obj = kwargs.get("response_format")

         async with context_manager as response:
+            is_first_chunk = True
             async for chunk in response:
+                metadata = headers if is_first_chunk else {}
                 if generation_chunk := _convert_responses_chunk_to_generation_chunk(
-                    chunk, schema=original_schema_obj
+                    chunk, schema=original_schema_obj, metadata=metadata
                 ):
                     if run_manager:
                         await run_manager.on_llm_new_token(
                             generation_chunk.text, chunk=generation_chunk
                         )
+                    is_first_chunk = False
                     yield generation_chunk

     def _stream(
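
The async variant mirrors the sync logic. A small sketch of consuming it (same illustrative assumptions as above):

```python
import asyncio

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(
    model="gpt-4o-mini", use_responses_api=True, include_response_headers=True
)

async def main() -> None:
    is_first = True
    async for chunk in llm.astream("Write a haiku about streams"):
        if is_first and "headers" in chunk.response_metadata:
            # The first chunk carries the HTTP headers when enabled.
            print(chunk.response_metadata["headers"].get("x-request-id"))
        is_first = False

asyncio.run(main())
```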
@@ -877,19 +901,26 @@ class BaseChatOpenAI(BaseChatModel):
                 response = self.root_client.beta.chat.completions.parse(**payload)
             except openai.BadRequestError as e:
                 _handle_openai_bad_request(e)
-        elif self.include_response_headers:
-            raw_response = self.client.with_raw_response.create(**payload)
-            response = raw_response.parse()
-            generation_info = {"headers": dict(raw_response.headers)}
         elif self._use_responses_api(payload):
             original_schema_obj = kwargs.get("response_format")
             if original_schema_obj and _is_pydantic_class(original_schema_obj):
                 response = self.root_client.responses.parse(**payload)
             else:
-                response = self.root_client.responses.create(**payload)
+                if self.include_response_headers:
+                    raw_response = self.root_client.with_raw_response.responses.create(
+                        **payload
+                    )
+                    response = raw_response.parse()
+                    generation_info = {"headers": dict(raw_response.headers)}
+                else:
+                    response = self.root_client.responses.create(**payload)
             return _construct_lc_result_from_responses_api(
-                response, schema=original_schema_obj
+                response, schema=original_schema_obj, metadata=generation_info
             )
+        elif self.include_response_headers:
+            raw_response = self.client.with_raw_response.create(**payload)
+            response = raw_response.parse()
+            generation_info = {"headers": dict(raw_response.headers)}
         else:
             response = self.client.create(**payload)
         return self._create_chat_result(response, generation_info)
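
The non-streaming branch leans on the openai SDK's `with_raw_response` wrapper: `.parse()` returns the usual typed object while `.headers` exposes the HTTP response headers. A standalone sketch of that pattern against the raw SDK (model name and input are assumptions):

```python
import openai

client = openai.OpenAI()
raw = client.with_raw_response.responses.create(
    model="gpt-4o-mini", input="Hello"
)
response = raw.parse()       # typed Response object, same as responses.create
headers = dict(raw.headers)  # HTTP headers, e.g. x-request-id, rate-limit info
```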
@@ -1065,20 +1096,28 @@ class BaseChatOpenAI(BaseChatModel):
                 )
             except openai.BadRequestError as e:
                 _handle_openai_bad_request(e)
-        elif self.include_response_headers:
-            raw_response = await self.async_client.with_raw_response.create(**payload)
-            response = raw_response.parse()
-            generation_info = {"headers": dict(raw_response.headers)}
         elif self._use_responses_api(payload):
             original_schema_obj = kwargs.get("response_format")
             if original_schema_obj and _is_pydantic_class(original_schema_obj):
                 response = await self.root_async_client.responses.parse(**payload)
             else:
-                response = await self.root_async_client.responses.create(**payload)
+                if self.include_response_headers:
+                    raw_response = (
+                        await self.root_async_client.with_raw_response.responses.create(
+                            **payload
+                        )
+                    )
+                    response = raw_response.parse()
+                    generation_info = {"headers": dict(raw_response.headers)}
+                else:
+                    response = await self.root_async_client.responses.create(**payload)
             return _construct_lc_result_from_responses_api(
-                response, schema=original_schema_obj
+                response, schema=original_schema_obj, metadata=generation_info
             )
+        elif self.include_response_headers:
+            raw_response = await self.async_client.with_raw_response.create(**payload)
+            response = raw_response.parse()
+            generation_info = {"headers": dict(raw_response.headers)}
         else:
             response = await self.async_client.create(**payload)
         return await run_in_executor(
@@ -2834,6 +2873,13 @@ def _use_responses_api(payload: dict) -> bool:
 def _construct_responses_api_payload(
     messages: Sequence[BaseMessage], payload: dict
 ) -> dict:
+    # Rename legacy parameters
+    for legacy_token_param in ["max_tokens", "max_completion_tokens"]:
+        if legacy_token_param in payload:
+            payload["max_output_tokens"] = payload.pop(legacy_token_param)
+    if "reasoning_effort" in payload:
+        payload["reasoning"] = {"effort": payload.pop("reasoning_effort")}
+
     payload["input"] = _construct_responses_api_input(messages)
     if tools := payload.pop("tools", None):
         new_tools: list = []
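
These renames let Chat Completions-style kwargs keep working when requests are routed to the Responses API. A self-contained illustration of the same mapping (the payload values are made up):

```python
payload = {"max_tokens": 256, "reasoning_effort": "high", "stream": False}

for legacy_token_param in ["max_tokens", "max_completion_tokens"]:
    if legacy_token_param in payload:
        payload["max_output_tokens"] = payload.pop(legacy_token_param)
if "reasoning_effort" in payload:
    payload["reasoning"] = {"effort": payload.pop("reasoning_effort")}

assert payload == {
    "stream": False,
    "max_output_tokens": 256,
    "reasoning": {"effort": "high"},
}
```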
@@ -2868,17 +2914,23 @@ def _construct_responses_api_payload(

         # For pydantic + non-streaming case, we use responses.parse.
         # Otherwise, we use responses.create.
+        strict = payload.pop("strict", None)
         if not payload.get("stream") and _is_pydantic_class(schema):
             payload["text_format"] = schema
         else:
             if _is_pydantic_class(schema):
                 schema_dict = schema.model_json_schema()
+                strict = True
             else:
                 schema_dict = schema
             if schema_dict == {"type": "json_object"}:  # JSON mode
                 payload["text"] = {"format": {"type": "json_object"}}
             elif (
-                (response_format := _convert_to_openai_response_format(schema_dict))
+                (
+                    response_format := _convert_to_openai_response_format(
+                        schema_dict, strict=strict
+                    )
+                )
                 and (isinstance(response_format, dict))
                 and (response_format["type"] == "json_schema")
             ):
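
With this change, a user-supplied `strict` flag (and the `strict=True` forced for Pydantic schemas) reaches `_convert_to_openai_response_format` rather than being dropped. A hedged sketch of the user-facing call this fixes; model name and prompt are assumptions:

```python
from langchain_openai import ChatOpenAI
from pydantic import BaseModel

class Answer(BaseModel):
    value: int

llm = ChatOpenAI(model="gpt-4o-mini", use_responses_api=True)
structured_llm = llm.with_structured_output(Answer, method="json_schema", strict=True)
result = structured_llm.invoke("What is 2 + 2? Answer with a number.")
# result is an Answer instance; strict now propagates into the
# json_schema format sent to the Responses API.
```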
@@ -2993,7 +3045,9 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:


 def _construct_lc_result_from_responses_api(
-    response: Response, schema: Optional[Type[_BM]] = None
+    response: Response,
+    schema: Optional[Type[_BM]] = None,
+    metadata: Optional[dict] = None,
 ) -> ChatResult:
     """Construct ChatResponse from OpenAI Response API response."""
     if response.error:
@@ -3014,6 +3068,8 @@ def _construct_lc_result_from_responses_api(
             "model",
         )
     }
+    if metadata:
+        response_metadata.update(metadata)
     # for compatibility with chat completion calls.
     response_metadata["model_name"] = response_metadata.get("model")
     if response.usage:
@@ -3099,17 +3155,21 @@ def _construct_lc_result_from_responses_api(
     if (
         schema is not None
         and "parsed" not in additional_kwargs
+        and response.output_text  # tool calls can generate empty output text
         and response.text
         and (text_config := response.text.model_dump())
         and (format_ := text_config.get("format", {}))
         and (format_.get("type") == "json_schema")
     ):
-        parsed_dict = json.loads(response.output_text)
-        if schema and _is_pydantic_class(schema):
-            parsed = schema(**parsed_dict)
-        else:
-            parsed = parsed_dict
-        additional_kwargs["parsed"] = parsed
+        try:
+            parsed_dict = json.loads(response.output_text)
+            if schema and _is_pydantic_class(schema):
+                parsed = schema(**parsed_dict)
+            else:
+                parsed = parsed_dict
+            additional_kwargs["parsed"] = parsed
+        except json.JSONDecodeError:
+            pass
     message = AIMessage(
         content=content_blocks,
         id=msg_id,
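
The new guard covers responses that contain only tool calls: `output_text` is then empty or not valid JSON, and `json.loads` previously raised. A sketch of the scenario that used to crash; the tool definition, model name, and prompt are illustrative assumptions:

```python
from langchain_openai import ChatOpenAI
from pydantic import BaseModel

class Weather(BaseModel):
    temperature: float

def get_weather(city: str) -> str:
    """Look up current weather for a city."""
    return "72F and sunny"

llm = ChatOpenAI(model="gpt-4o-mini", use_responses_api=True)
llm_with_tools = llm.bind_tools([get_weather])
# Structured output requested alongside tools: if the model answers with a
# tool call, output_text is empty and the parse step is now skipped instead
# of raising json.JSONDecodeError.
result = llm_with_tools.invoke(
    "What's the weather in Paris?", response_format=Weather
)
```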
@@ -3123,12 +3183,15 @@ def _construct_lc_result_from_responses_api(


 def _convert_responses_chunk_to_generation_chunk(
-    chunk: Any, schema: Optional[Type[_BM]] = None
+    chunk: Any, schema: Optional[Type[_BM]] = None, metadata: Optional[dict] = None
 ) -> Optional[ChatGenerationChunk]:
     content = []
     tool_call_chunks: list = []
     additional_kwargs: dict = {}
-    response_metadata = {}
+    if metadata:
+        response_metadata = metadata
+    else:
+        response_metadata = {}
     usage_metadata = None
     id = None
     if chunk.type == "response.output_text.delta":