This commit is contained in:
Bagatur 2025-03-12 03:20:40 -07:00
parent e8c4787430
commit 3c25ee70c1
4 changed files with 188 additions and 152 deletions

View File

@ -914,9 +914,9 @@ class BaseChatOpenAI(BaseChatModel):
raw_response = self.client.with_raw_response.create(**payload)
response = raw_response.parse()
generation_info = {"headers": dict(raw_response.headers)}
elif _use_response_api(payload):
elif _use_responses_api(payload):
response = self.root_client.responses.create(**payload)
return _construct_lc_result_from_response_api(response)
return _construct_lc_result_from_responses_api(response)
else:
response = self.client.create(**payload)
return self._create_chat_result(response, generation_info)
@ -933,8 +933,8 @@ class BaseChatOpenAI(BaseChatModel):
kwargs["stop"] = stop
payload = {**self._default_params, **kwargs}
if _use_response_api(payload):
payload = _construct_response_api_payload(messages, payload)
if _use_responses_api(payload):
payload = _construct_responses_api_payload(messages, payload)
else:
payload["messages"] = [_convert_message_to_dict(m) for m in messages]
return payload
@ -1088,9 +1088,9 @@ class BaseChatOpenAI(BaseChatModel):
raw_response = await self.async_client.with_raw_response.create(**payload)
response = raw_response.parse()
generation_info = {"headers": dict(raw_response.headers)}
elif _use_response_api(payload):
elif _use_responses_api(payload):
response = await self.root_async_client.responses.create(**payload)
return _construct_lc_result_from_response_api(response)
return _construct_lc_result_from_responses_api(response)
else:
response = await self.async_client.create(**payload)
return await run_in_executor(
@ -2189,7 +2189,7 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
) -> Iterator[ChatGenerationChunk]:
"""Set default stream_options."""
if _use_response_api(kwargs):
if _use_responses_api(kwargs):
return super()._stream_responses(*args, **kwargs)
else:
stream_usage = self._should_stream_usage(stream_usage, **kwargs)
@ -2207,7 +2207,7 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
) -> AsyncIterator[ChatGenerationChunk]:
"""Set default stream_options."""
if _use_response_api(kwargs):
if _use_responses_api(kwargs):
async for chunk in super()._astream_responses(*args, **kwargs):
yield chunk
else:
@ -2760,16 +2760,18 @@ def _is_builtin_tool(tool: dict) -> bool:
return "type" in tool and tool["type"] != "function"
def _use_response_api(payload: dict) -> bool:
return "tools" in payload and any(
def _use_responses_api(payload: dict) -> bool:
    """Decide whether a request payload should be sent to the Responses API.

    The payload is routed to the Responses API when either:

    * any entry of ``payload["tools"]`` is a built-in tool according to
      ``_is_builtin_tool``, or
    * the payload contains a key that only the Responses API accepts
      (``previous_response_id``, ``text``, ``truncation`` or ``include``).

    Args:
        payload: Request parameters keyed by argument name.

    Returns:
        ``True`` if the Responses API should be used, ``False`` otherwise.
    """
    # ``tools`` can be present but set to None (for example when passed
    # explicitly as a keyword argument). Treat that like "no tools" rather
    # than raising TypeError while iterating over None.
    tools = payload.get("tools") or ()
    uses_builtin_tools = any(_is_builtin_tool(tool) for tool in tools)
    # Only key presence matters for these arguments; their values are not
    # inspected.
    responses_only_args = {"previous_response_id", "text", "truncation", "include"}
    return bool(uses_builtin_tools or responses_only_args.intersection(payload))
def _construct_response_api_payload(
def _construct_responses_api_payload(
messages: Sequence[BaseMessage], payload: dict
) -> dict:
payload["input"] = _construct_response_api_input(messages)
payload["input"] = _construct_responses_api_input(messages)
if tools := payload.pop("tools", None):
new_tools: list = []
for tool in tools:
@ -2803,7 +2805,7 @@ def _construct_response_api_payload(
return payload
def _construct_response_api_input(messages: Sequence[BaseMessage]) -> list:
def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
input_ = []
for lc_msg in messages:
msg = _convert_message_to_dict(lc_msg)
@ -2899,7 +2901,7 @@ def _construct_response_api_input(messages: Sequence[BaseMessage]) -> list:
return input_
def _construct_lc_result_from_response_api(response: Response) -> ChatResult:
def _construct_lc_result_from_responses_api(response: Response) -> ChatResult:
"""Construct ChatResponse from OpenAI Response API response."""
if response.error:
raise ValueError(response.error)

View File

@ -2,7 +2,6 @@
import base64
import json
import os
from pathlib import Path
from textwrap import dedent
from typing import Any, AsyncIterator, List, Literal, Optional, cast
@ -1229,102 +1228,3 @@ def test_structured_output_and_tools() -> None:
assert len(full.tool_calls) == 1
tool_call = full.tool_calls[0]
assert tool_call["name"] == "GenerateUsername"
def _check_response(response: Optional[BaseMessage]) -> None:
assert isinstance(response, AIMessage)
assert isinstance(response.content, list)
for block in response.content:
assert isinstance(block, dict)
if block["type"] == "text":
assert isinstance(block["text"], str)
for annotation in block["annotations"]:
if annotation["type"] == "file_citation":
assert all(
key in annotation
for key in ["file_id", "filename", "index", "type"]
)
elif annotation["type"] == "web_search":
assert all(
key in annotation
for key in ["end_index", "start_index", "title", "type", "url"]
)
text_content = response.text()
assert isinstance(text_content, str)
assert text_content
assert response.usage_metadata
assert response.usage_metadata["input_tokens"] > 0
assert response.usage_metadata["output_tokens"] > 0
assert response.usage_metadata["total_tokens"] > 0
assert response.response_metadata["model_name"]
for tool_output in response.additional_kwargs["tool_outputs"]:
assert tool_output["id"]
assert tool_output["status"]
assert tool_output["type"]
def test_web_search() -> None:
llm = ChatOpenAI(model="gpt-4o-mini")
response = llm.invoke(
"What was a positive news story from today?",
tools=[{"type": "web_search_preview"}],
)
_check_response(response)
assert response.response_metadata["status"]
# Test streaming
full: Optional[BaseMessageChunk] = None
for chunk in llm.stream(
"What was a positive news story from today?",
tools=[{"type": "web_search_preview"}],
):
assert isinstance(chunk, AIMessageChunk)
full = chunk if full is None else full + chunk
_check_response(full)
llm.invoke(
"what about a negative one",
tools=[{"type": "web_search_preview"}],
response_id=response.response_metadata["id"],
)
_check_response(response)
async def test_web_search_async() -> None:
llm = ChatOpenAI(model="gpt-4o-mini")
response = await llm.ainvoke(
"What was a positive news story from today?",
tools=[{"type": "web_search_preview"}],
)
_check_response(response)
assert response.response_metadata["status"]
# Test streaming
full: Optional[BaseMessageChunk] = None
async for chunk in llm.astream(
"What was a positive news story from today?",
tools=[{"type": "web_search_preview"}],
):
assert isinstance(chunk, AIMessageChunk)
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk)
_check_response(full)
def test_file_search() -> None:
pytest.skip() # TODO: set up infra
llm = ChatOpenAI(model="gpt-4o-mini")
tool = {
"type": "file_search",
"vector_store_ids": [os.environ["OPENAI_VECTOR_STORE_ID"]],
}
response = llm.invoke("What is deep research by OpenAI?", tools=[tool])
_check_response(response)
full: Optional[BaseMessageChunk] = None
for chunk in llm.stream("What is deep research by OpenAI?", tools=[tool]):
assert isinstance(chunk, AIMessageChunk)
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk)
_check_response(full)

View File

@ -0,0 +1,132 @@
"""Test Responses API usage."""
import os
from typing import Optional
import pytest
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
)
from langchain_openai import ChatOpenAI
def _check_response(response: Optional[BaseMessage]) -> None:
    """Assert that ``response`` looks like a well-formed Responses API message.

    Checks, in order: the message is an ``AIMessage`` whose content is a list
    of dict blocks; text blocks carry a string ``text`` and annotations with
    the keys required for their type; ``response.text()`` is a non-empty
    string; usage counters are positive; a model name is recorded; and every
    entry of ``additional_kwargs["tool_outputs"]`` has truthy ``id``,
    ``status`` and ``type`` values.
    """
    assert isinstance(response, AIMessage)
    assert isinstance(response.content, list)

    # Keys each recognised annotation type must expose. Annotation types not
    # listed here are accepted without further checks.
    required_keys_by_type = {
        "file_citation": ["file_id", "filename", "index", "type"],
        "web_search": ["end_index", "start_index", "title", "type", "url"],
    }
    for content_block in response.content:
        assert isinstance(content_block, dict)
        if content_block["type"] != "text":
            continue
        assert isinstance(content_block["text"], str)
        for annotation in content_block["annotations"]:
            required_keys = required_keys_by_type.get(annotation["type"], [])
            assert all(key in annotation for key in required_keys)

    text_content = response.text()
    assert isinstance(text_content, str)
    assert text_content

    usage = response.usage_metadata
    assert usage
    for counter in ("input_tokens", "output_tokens", "total_tokens"):
        assert usage[counter] > 0

    assert response.response_metadata["model_name"]

    for tool_output in response.additional_kwargs["tool_outputs"]:
        for field in ("id", "status", "type"):
            assert tool_output[field]
def test_web_search() -> None:
    """Exercise the built-in web search tool through the synchronous API.

    Covers a plain invoke, streaming with chunk aggregation, a follow-up that
    references the first response via ``previous_response_id``, a follow-up
    that passes the first response back as chat history, and a call made
    through ``bind_tools``.
    """
    positive_prompt = "What was a positive news story from today?"
    negative_prompt = "what about a negative one"
    llm = ChatOpenAI(model="gpt-4o-mini")

    first_response = llm.invoke(
        positive_prompt, tools=[{"type": "web_search_preview"}]
    )
    _check_response(first_response)

    # Test streaming
    full: Optional[BaseMessageChunk] = None
    for chunk in llm.stream(positive_prompt, tools=[{"type": "web_search_preview"}]):
        assert isinstance(chunk, AIMessageChunk)
        if full is None:
            full = chunk
        else:
            full = full + chunk
    _check_response(full)

    # Use OpenAI's stateful API
    previous_id = first_response.response_metadata["id"]
    response = llm.invoke(
        negative_prompt,
        tools=[{"type": "web_search_preview"}],
        previous_response_id=previous_id,
    )
    _check_response(response)

    # Manually pass in chat history
    follow_up_message = {
        "role": "user",
        "content": [{"type": "text", "text": negative_prompt}],
    }
    response = llm.invoke(
        [first_response, follow_up_message],
        tools=[{"type": "web_search_preview"}],
    )
    _check_response(response)

    # Bind tool
    llm_with_search = llm.bind_tools([{"type": "web_search_preview"}])
    response = llm_with_search.invoke(positive_prompt)
    _check_response(response)
async def test_web_search_async() -> None:
    """Exercise the built-in web search tool through the async API.

    Runs one ``ainvoke`` call and one ``astream`` call with the
    ``web_search_preview`` tool and validates both results with
    ``_check_response``.
    """
    prompt = "What was a positive news story from today?"
    llm = ChatOpenAI(model="gpt-4o-mini")

    response = await llm.ainvoke(prompt, tools=[{"type": "web_search_preview"}])
    _check_response(response)
    assert response.response_metadata["status"]

    # Test streaming
    full: Optional[BaseMessageChunk] = None
    async for chunk in llm.astream(prompt, tools=[{"type": "web_search_preview"}]):
        assert isinstance(chunk, AIMessageChunk)
        if full is None:
            full = chunk
        else:
            full = full + chunk
    assert isinstance(full, AIMessageChunk)
    _check_response(full)
def test_file_search() -> None:
    """Exercise the built-in file search tool via invoke and stream.

    Currently skipped unconditionally by the ``pytest.skip()`` call below, so
    nothing after it runs. The vector store to search is read from the
    ``OPENAI_VECTOR_STORE_ID`` environment variable.
    """
    pytest.skip()  # TODO: set up infra
    llm = ChatOpenAI(model="gpt-4o-mini")
    vector_store_id = os.environ["OPENAI_VECTOR_STORE_ID"]
    tool = {"type": "file_search", "vector_store_ids": [vector_store_id]}
    prompt = "What is deep research by OpenAI?"

    response = llm.invoke(prompt, tools=[tool])
    _check_response(response)

    # Aggregate the streamed chunks into one message before validating it.
    full: Optional[BaseMessageChunk] = None
    for chunk in llm.stream(prompt, tools=[tool]):
        assert isinstance(chunk, AIMessageChunk)
        if full is None:
            full = chunk
        else:
            full = full + chunk
    assert isinstance(full, AIMessageChunk)
    _check_response(full)

View File

@ -41,8 +41,8 @@ from typing_extensions import TypedDict
from langchain_openai import ChatOpenAI
from langchain_openai.chat_models.base import (
_FUNCTION_CALL_IDS_MAP_KEY,
_construct_lc_result_from_response_api,
_construct_response_api_input,
_construct_lc_result_from_responses_api,
_construct_responses_api_input,
_convert_dict_to_message,
_convert_message_to_dict,
_convert_to_openai_response_format,
@ -955,7 +955,7 @@ def test_structured_outputs_parser() -> None:
assert result == parsed_response
def test__construct_lc_result_from_response_api_error_handling() -> None:
def test__construct_lc_result_from_responses_api_error_handling() -> None:
"""Test that errors in the response are properly raised."""
response = Response(
id="resp_123",
@ -970,12 +970,12 @@ def test__construct_lc_result_from_response_api_error_handling() -> None:
)
with pytest.raises(ValueError) as excinfo:
_construct_lc_result_from_response_api(response)
_construct_lc_result_from_responses_api(response)
assert "Test error" in str(excinfo.value)
def test__construct_lc_result_from_response_api_basic_text_response() -> None:
def test__construct_lc_result_from_responses_api_basic_text_response() -> None:
"""Test a basic text response with no tools or special features."""
response = Response(
id="resp_123",
@ -1006,7 +1006,7 @@ def test__construct_lc_result_from_response_api_basic_text_response() -> None:
),
)
result = _construct_lc_result_from_response_api(response)
result = _construct_lc_result_from_responses_api(response)
assert isinstance(result, ChatResult)
assert len(result.generations) == 1
@ -1024,7 +1024,7 @@ def test__construct_lc_result_from_response_api_basic_text_response() -> None:
assert result.generations[0].message.response_metadata["model_name"] == "gpt-4o"
def test__construct_lc_result_from_response_api_multiple_text_blocks() -> None:
def test__construct_lc_result_from_responses_api_multiple_text_blocks() -> None:
"""Test a response with multiple text blocks."""
response = Response(
id="resp_123",
@ -1052,14 +1052,14 @@ def test__construct_lc_result_from_response_api_multiple_text_blocks() -> None:
],
)
result = _construct_lc_result_from_response_api(response)
result = _construct_lc_result_from_responses_api(response)
assert len(result.generations[0].message.content) == 2
assert result.generations[0].message.content[0]["text"] == "First part" # type: ignore
assert result.generations[0].message.content[1]["text"] == "Second part" # type: ignore
def test__construct_lc_result_from_response_api_refusal_response() -> None:
def test__construct_lc_result_from_responses_api_refusal_response() -> None:
"""Test a response with a refusal."""
response = Response(
id="resp_123",
@ -1084,7 +1084,7 @@ def test__construct_lc_result_from_response_api_refusal_response() -> None:
],
)
result = _construct_lc_result_from_response_api(response)
result = _construct_lc_result_from_responses_api(response)
assert result.generations[0].message.content == []
assert (
@ -1093,7 +1093,7 @@ def test__construct_lc_result_from_response_api_refusal_response() -> None:
)
def test__construct_lc_result_from_response_api_function_call_valid_json() -> None:
def test__construct_lc_result_from_responses_api_function_call_valid_json() -> None:
"""Test a response with a valid function call."""
response = Response(
id="resp_123",
@ -1114,7 +1114,7 @@ def test__construct_lc_result_from_response_api_function_call_valid_json() -> No
],
)
result = _construct_lc_result_from_response_api(response)
result = _construct_lc_result_from_responses_api(response)
msg: AIMessage = cast(AIMessage, result.generations[0].message)
assert len(msg.tool_calls) == 1
@ -1131,7 +1131,7 @@ def test__construct_lc_result_from_response_api_function_call_valid_json() -> No
)
def test__construct_lc_result_from_response_api_function_call_invalid_json() -> None:
def test__construct_lc_result_from_responses_api_function_call_invalid_json() -> None:
"""Test a response with an invalid JSON function call."""
response = Response(
id="resp_123",
@ -1153,7 +1153,7 @@ def test__construct_lc_result_from_response_api_function_call_invalid_json() ->
],
)
result = _construct_lc_result_from_response_api(response)
result = _construct_lc_result_from_responses_api(response)
msg: AIMessage = cast(AIMessage, result.generations[0].message)
assert len(msg.invalid_tool_calls) == 1
@ -1168,7 +1168,7 @@ def test__construct_lc_result_from_response_api_function_call_invalid_json() ->
assert _FUNCTION_CALL_IDS_MAP_KEY in result.generations[0].message.additional_kwargs
def test__construct_lc_result_from_response_api_complex_response() -> None:
def test__construct_lc_result_from_responses_api_complex_response() -> None:
"""Test a complex response with multiple output types."""
response = Response(
id="resp_123",
@ -1206,7 +1206,7 @@ def test__construct_lc_result_from_response_api_complex_response() -> None:
user="user_123",
)
result = _construct_lc_result_from_response_api(response)
result = _construct_lc_result_from_responses_api(response)
# Check message content
assert result.generations[0].message.content == [
@ -1235,7 +1235,7 @@ def test__construct_lc_result_from_response_api_complex_response() -> None:
assert result.generations[0].message.response_metadata["user"] == "user_123"
def test__construct_lc_result_from_response_api_no_usage_metadata() -> None:
def test__construct_lc_result_from_responses_api_no_usage_metadata() -> None:
"""Test a response without usage metadata."""
response = Response(
id="resp_123",
@ -1261,12 +1261,12 @@ def test__construct_lc_result_from_response_api_no_usage_metadata() -> None:
# No usage field
)
result = _construct_lc_result_from_response_api(response)
result = _construct_lc_result_from_responses_api(response)
assert cast(AIMessage, result.generations[0].message).usage_metadata is None
def test__construct_lc_result_from_response_api_web_search_response() -> None:
def test__construct_lc_result_from_responses_api_web_search_response() -> None:
"""Test a response with web search output."""
from openai.types.responses.response_function_web_search import (
ResponseFunctionWebSearch,
@ -1287,7 +1287,7 @@ def test__construct_lc_result_from_response_api_web_search_response() -> None:
],
)
result = _construct_lc_result_from_response_api(response)
result = _construct_lc_result_from_responses_api(response)
assert "tool_outputs" in result.generations[0].message.additional_kwargs
assert len(result.generations[0].message.additional_kwargs["tool_outputs"]) == 1
@ -1305,7 +1305,7 @@ def test__construct_lc_result_from_response_api_web_search_response() -> None:
)
def test__construct_lc_result_from_response_api_file_search_response() -> None:
def test__construct_lc_result_from_responses_api_file_search_response() -> None:
"""Test a response with file search output."""
response = Response(
id="resp_123",
@ -1334,7 +1334,7 @@ def test__construct_lc_result_from_response_api_file_search_response() -> None:
],
)
result = _construct_lc_result_from_response_api(response)
result = _construct_lc_result_from_responses_api(response)
assert "tool_outputs" in result.generations[0].message.additional_kwargs
assert len(result.generations[0].message.additional_kwargs["tool_outputs"]) == 1
@ -1375,7 +1375,7 @@ def test__construct_lc_result_from_response_api_file_search_response() -> None:
)
def test__construct_lc_result_from_response_api_mixed_search_responses() -> None:
def test__construct_lc_result_from_responses_api_mixed_search_responses() -> None:
"""Test a response with both web search and file search outputs."""
response = Response(
@ -1418,7 +1418,7 @@ def test__construct_lc_result_from_response_api_mixed_search_responses() -> None
],
)
result = _construct_lc_result_from_response_api(response)
result = _construct_lc_result_from_responses_api(response)
# Check message content
assert result.generations[0].message.content == [
@ -1449,14 +1449,14 @@ def test__construct_lc_result_from_response_api_mixed_search_responses() -> None
assert file_search["results"][0]["filename"] == "example.py"
def test__construct_response_api_input_human_message_with_text_blocks_conversion() -> (
def test__construct_responses_api_input_human_message_with_text_blocks_conversion() -> (
None
):
"""Test that human messages with text blocks are properly converted."""
messages: list = [
HumanMessage(content=[{"type": "text", "text": "What's in this image?"}])
]
result = _construct_response_api_input(messages)
result = _construct_responses_api_input(messages)
assert len(result) == 1
assert result[0]["role"] == "user"
@ -1466,7 +1466,7 @@ def test__construct_response_api_input_human_message_with_text_blocks_conversion
assert result[0]["content"][0]["text"] == "What's in this image?"
def test__construct_response_api_input_human_message_with_image_url_conversion() -> (
def test__construct_responses_api_input_human_message_with_image_url_conversion() -> (
None
):
"""Test that human messages with image_url blocks are properly converted."""
@ -1484,7 +1484,7 @@ def test__construct_response_api_input_human_message_with_image_url_conversion()
]
)
]
result = _construct_response_api_input(messages)
result = _construct_responses_api_input(messages)
assert len(result) == 1
assert result[0]["role"] == "user"
@ -1501,7 +1501,7 @@ def test__construct_response_api_input_human_message_with_image_url_conversion()
assert result[0]["content"][1]["detail"] == "high"
def test__construct_response_api_input_ai_message_with_tool_calls() -> None:
def test__construct_responses_api_input_ai_message_with_tool_calls() -> None:
"""Test that AI messages with tool calls are properly converted."""
tool_calls = [
{
@ -1521,7 +1521,7 @@ def test__construct_response_api_input_ai_message_with_tool_calls() -> None:
additional_kwargs={_FUNCTION_CALL_IDS_MAP_KEY: function_call_ids},
)
result = _construct_response_api_input([ai_message])
result = _construct_responses_api_input([ai_message])
assert len(result) == 1
assert result[0]["type"] == "function_call"
@ -1531,7 +1531,9 @@ def test__construct_response_api_input_ai_message_with_tool_calls() -> None:
assert result[0]["id"] == "func_456"
def test__construct_response_api_input_ai_message_with_tool_calls_and_content() -> None:
def test__construct_responses_api_input_ai_message_with_tool_calls_and_content() -> (
None
):
"""Test that AI messages with both tool calls and content are properly converted."""
tool_calls = [
{
@ -1551,7 +1553,7 @@ def test__construct_response_api_input_ai_message_with_tool_calls_and_content()
additional_kwargs={_FUNCTION_CALL_IDS_MAP_KEY: function_call_ids},
)
result = _construct_response_api_input([ai_message])
result = _construct_responses_api_input([ai_message])
assert len(result) == 2
@ -1567,7 +1569,7 @@ def test__construct_response_api_input_ai_message_with_tool_calls_and_content()
assert result[1]["id"] == "func_456"
def test__construct_response_api_input_missing_function_call_ids() -> None:
def test__construct_responses_api_input_missing_function_call_ids() -> None:
"""Test AI messages with tool calls but missing function call IDs raise an error."""
tool_calls = [
{
@ -1581,10 +1583,10 @@ def test__construct_response_api_input_missing_function_call_ids() -> None:
ai_message = AIMessage(content="", tool_calls=tool_calls)
with pytest.raises(ValueError):
_construct_response_api_input([ai_message])
_construct_responses_api_input([ai_message])
def test__construct_response_api_input_tool_message_conversion() -> None:
def test__construct_responses_api_input_tool_message_conversion() -> None:
"""Test that tool messages are properly converted to function_call_output."""
messages = [
ToolMessage(
@ -1593,7 +1595,7 @@ def test__construct_response_api_input_tool_message_conversion() -> None:
)
]
result = _construct_response_api_input(messages)
result = _construct_responses_api_input(messages)
assert len(result) == 1
assert result[0]["type"] == "function_call_output"
@ -1601,7 +1603,7 @@ def test__construct_response_api_input_tool_message_conversion() -> None:
assert result[0]["call_id"] == "call_123"
def test__construct_response_api_input_multiple_message_types() -> None:
def test__construct_responses_api_input_multiple_message_types() -> None:
"""Test conversion of a conversation with multiple message types."""
messages = [
SystemMessage(content="You are a helpful assistant."),
@ -1637,7 +1639,7 @@ def test__construct_response_api_input_multiple_message_types() -> None:
]
messages_copy = [m.copy(deep=True) for m in messages]
result = _construct_response_api_input(messages)
result = _construct_responses_api_input(messages)
assert len(result) == len(messages)