Bagatur 2025-03-12 03:20:40 -07:00
parent e8c4787430
commit 3c25ee70c1
4 changed files with 188 additions and 152 deletions

View File

@@ -914,9 +914,9 @@ class BaseChatOpenAI(BaseChatModel):
raw_response = self.client.with_raw_response.create(**payload)
response = raw_response.parse()
generation_info = {"headers": dict(raw_response.headers)}
- elif _use_response_api(payload):
+ elif _use_responses_api(payload):
response = self.root_client.responses.create(**payload)
- return _construct_lc_result_from_response_api(response)
+ return _construct_lc_result_from_responses_api(response)
else:
response = self.client.create(**payload)
return self._create_chat_result(response, generation_info)
@@ -933,8 +933,8 @@ class BaseChatOpenAI(BaseChatModel):
kwargs["stop"] = stop
payload = {**self._default_params, **kwargs}
- if _use_response_api(payload):
- payload = _construct_response_api_payload(messages, payload)
+ if _use_responses_api(payload):
+ payload = _construct_responses_api_payload(messages, payload)
else:
payload["messages"] = [_convert_message_to_dict(m) for m in messages]
return payload
@@ -1088,9 +1088,9 @@ class BaseChatOpenAI(BaseChatModel):
raw_response = await self.async_client.with_raw_response.create(**payload)
response = raw_response.parse()
generation_info = {"headers": dict(raw_response.headers)}
- elif _use_response_api(payload):
+ elif _use_responses_api(payload):
response = await self.root_async_client.responses.create(**payload)
- return _construct_lc_result_from_response_api(response)
+ return _construct_lc_result_from_responses_api(response)
else:
response = await self.async_client.create(**payload)
return await run_in_executor(
@@ -2189,7 +2189,7 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
) -> Iterator[ChatGenerationChunk]:
"""Set default stream_options."""
- if _use_response_api(kwargs):
+ if _use_responses_api(kwargs):
return super()._stream_responses(*args, **kwargs)
else:
stream_usage = self._should_stream_usage(stream_usage, **kwargs)
@@ -2207,7 +2207,7 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
) -> AsyncIterator[ChatGenerationChunk]:
"""Set default stream_options."""
- if _use_response_api(kwargs):
+ if _use_responses_api(kwargs):
async for chunk in super()._astream_responses(*args, **kwargs):
yield chunk
else:
@@ -2760,16 +2760,18 @@ def _is_builtin_tool(tool: dict) -> bool:
return "type" in tool and tool["type"] != "function"
- def _use_response_api(payload: dict) -> bool:
- return "tools" in payload and any(
+ def _use_responses_api(payload: dict) -> bool:
+ uses_builtin_tools = "tools" in payload and any(
_is_builtin_tool(tool) for tool in payload["tools"]
)
+ responses_only_args = {"previous_response_id", "text", "truncation", "include"}
+ return bool(uses_builtin_tools or responses_only_args.intersection(payload))
- def _construct_response_api_payload(
+ def _construct_responses_api_payload(
messages: Sequence[BaseMessage], payload: dict
) -> dict:
- payload["input"] = _construct_response_api_input(messages)
+ payload["input"] = _construct_responses_api_input(messages)
if tools := payload.pop("tools", None):
new_tools: list = []
for tool in tools:
@@ -2803,7 +2805,7 @@ def _construct_response_api_payload(
return payload
- def _construct_response_api_input(messages: Sequence[BaseMessage]) -> list:
+ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
input_ = []
for lc_msg in messages:
msg = _convert_message_to_dict(lc_msg)
@@ -2899,7 +2901,7 @@ def _construct_response_api_input(messages: Sequence[BaseMessage]) -> list:
return input_
- def _construct_lc_result_from_response_api(response: Response) -> ChatResult:
+ def _construct_lc_result_from_responses_api(response: Response) -> ChatResult:
"""Construct ChatResponse from OpenAI Response API response."""
if response.error:
raise ValueError(response.error)
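
The _use_responses_api heuristic above is the crux of the change: a payload is routed to client.responses.create whenever it contains a built-in tool or any Responses-only argument. A minimal standalone sketch of that heuristic with illustrative payloads (the example payload contents are not taken from the commit):

# Standalone copy of the heuristic added above, for illustration only.
def _is_builtin_tool(tool: dict) -> bool:
    return "type" in tool and tool["type"] != "function"

def _use_responses_api(payload: dict) -> bool:
    uses_builtin_tools = "tools" in payload and any(
        _is_builtin_tool(tool) for tool in payload["tools"]
    )
    responses_only_args = {"previous_response_id", "text", "truncation", "include"}
    return bool(uses_builtin_tools or responses_only_args.intersection(payload))

# A built-in tool such as web search opts into the Responses API...
assert _use_responses_api({"tools": [{"type": "web_search_preview"}]})
# ...and so does a Responses-only argument, even without tools...
assert _use_responses_api({"previous_response_id": "resp_123"})
# ...while plain function tools keep using Chat Completions.
assert not _use_responses_api({"tools": [{"type": "function", "function": {"name": "f"}}]})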

View File

@@ -2,7 +2,6 @@
import base64
import json
- import os
from pathlib import Path
from textwrap import dedent
from typing import Any, AsyncIterator, List, Literal, Optional, cast
@@ -1229,102 +1228,3 @@ def test_structured_output_and_tools() -> None:
assert len(full.tool_calls) == 1
tool_call = full.tool_calls[0]
assert tool_call["name"] == "GenerateUsername"
def _check_response(response: Optional[BaseMessage]) -> None:
assert isinstance(response, AIMessage)
assert isinstance(response.content, list)
for block in response.content:
assert isinstance(block, dict)
if block["type"] == "text":
assert isinstance(block["text"], str)
for annotation in block["annotations"]:
if annotation["type"] == "file_citation":
assert all(
key in annotation
for key in ["file_id", "filename", "index", "type"]
)
elif annotation["type"] == "web_search":
assert all(
key in annotation
for key in ["end_index", "start_index", "title", "type", "url"]
)
text_content = response.text()
assert isinstance(text_content, str)
assert text_content
assert response.usage_metadata
assert response.usage_metadata["input_tokens"] > 0
assert response.usage_metadata["output_tokens"] > 0
assert response.usage_metadata["total_tokens"] > 0
assert response.response_metadata["model_name"]
for tool_output in response.additional_kwargs["tool_outputs"]:
assert tool_output["id"]
assert tool_output["status"]
assert tool_output["type"]
def test_web_search() -> None:
llm = ChatOpenAI(model="gpt-4o-mini")
response = llm.invoke(
"What was a positive news story from today?",
tools=[{"type": "web_search_preview"}],
)
_check_response(response)
assert response.response_metadata["status"]
# Test streaming
full: Optional[BaseMessageChunk] = None
for chunk in llm.stream(
"What was a positive news story from today?",
tools=[{"type": "web_search_preview"}],
):
assert isinstance(chunk, AIMessageChunk)
full = chunk if full is None else full + chunk
_check_response(full)
llm.invoke(
"what about a negative one",
tools=[{"type": "web_search_preview"}],
response_id=response.response_metadata["id"],
)
_check_response(response)
async def test_web_search_async() -> None:
llm = ChatOpenAI(model="gpt-4o-mini")
response = await llm.ainvoke(
"What was a positive news story from today?",
tools=[{"type": "web_search_preview"}],
)
_check_response(response)
assert response.response_metadata["status"]
# Test streaming
full: Optional[BaseMessageChunk] = None
async for chunk in llm.astream(
"What was a positive news story from today?",
tools=[{"type": "web_search_preview"}],
):
assert isinstance(chunk, AIMessageChunk)
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk)
_check_response(full)
def test_file_search() -> None:
pytest.skip() # TODO: set up infra
llm = ChatOpenAI(model="gpt-4o-mini")
tool = {
"type": "file_search",
"vector_store_ids": [os.environ["OPENAI_VECTOR_STORE_ID"]],
}
response = llm.invoke("What is deep research by OpenAI?", tools=[tool])
_check_response(response)
full: Optional[BaseMessageChunk] = None
for chunk in llm.stream("What is deep research by OpenAI?", tools=[tool]):
assert isinstance(chunk, AIMessageChunk)
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk)
_check_response(full)

View File

@@ -0,0 +1,132 @@
"""Test Responses API usage."""
import os
from typing import Optional
import pytest
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
)
from langchain_openai import ChatOpenAI
def _check_response(response: Optional[BaseMessage]) -> None:
assert isinstance(response, AIMessage)
assert isinstance(response.content, list)
for block in response.content:
assert isinstance(block, dict)
if block["type"] == "text":
assert isinstance(block["text"], str)
for annotation in block["annotations"]:
if annotation["type"] == "file_citation":
assert all(
key in annotation
for key in ["file_id", "filename", "index", "type"]
)
elif annotation["type"] == "web_search":
assert all(
key in annotation
for key in ["end_index", "start_index", "title", "type", "url"]
)
text_content = response.text()
assert isinstance(text_content, str)
assert text_content
assert response.usage_metadata
assert response.usage_metadata["input_tokens"] > 0
assert response.usage_metadata["output_tokens"] > 0
assert response.usage_metadata["total_tokens"] > 0
assert response.response_metadata["model_name"]
for tool_output in response.additional_kwargs["tool_outputs"]:
assert tool_output["id"]
assert tool_output["status"]
assert tool_output["type"]
def test_web_search() -> None:
llm = ChatOpenAI(model="gpt-4o-mini")
first_response = llm.invoke(
"What was a positive news story from today?",
tools=[{"type": "web_search_preview"}],
)
_check_response(first_response)
# Test streaming
full: Optional[BaseMessageChunk] = None
for chunk in llm.stream(
"What was a positive news story from today?",
tools=[{"type": "web_search_preview"}],
):
assert isinstance(chunk, AIMessageChunk)
full = chunk if full is None else full + chunk
_check_response(full)
# Use OpenAI's stateful API
response = llm.invoke(
"what about a negative one",
tools=[{"type": "web_search_preview"}],
previous_response_id=first_response.response_metadata["id"],
)
_check_response(response)
# Manually pass in chat history
response = llm.invoke(
[
first_response,
{
"role": "user",
"content": [{"type": "text", "text": "what about a negative one"}],
},
],
tools=[{"type": "web_search_preview"}],
)
_check_response(response)
# Bind tool
response = llm.bind_tools([{"type": "web_search_preview"}]).invoke(
"What was a positive news story from today?"
)
_check_response(response)
async def test_web_search_async() -> None:
llm = ChatOpenAI(model="gpt-4o-mini")
response = await llm.ainvoke(
"What was a positive news story from today?",
tools=[{"type": "web_search_preview"}],
)
_check_response(response)
assert response.response_metadata["status"]
# Test streaming
full: Optional[BaseMessageChunk] = None
async for chunk in llm.astream(
"What was a positive news story from today?",
tools=[{"type": "web_search_preview"}],
):
assert isinstance(chunk, AIMessageChunk)
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk)
_check_response(full)
def test_file_search() -> None:
pytest.skip() # TODO: set up infra
llm = ChatOpenAI(model="gpt-4o-mini")
tool = {
"type": "file_search",
"vector_store_ids": [os.environ["OPENAI_VECTOR_STORE_ID"]],
}
response = llm.invoke("What is deep research by OpenAI?", tools=[tool])
_check_response(response)
full: Optional[BaseMessageChunk] = None
for chunk in llm.stream("What is deep research by OpenAI?", tools=[tool]):
assert isinstance(chunk, AIMessageChunk)
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk)
_check_response(full)
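
_check_response above pins down the content-block shape expected from the Responses API path. An illustrative block that would satisfy those assertions (keys follow the test's checks; all values are invented):

# Example content block an AIMessage from the Responses path might carry.
example_content_block = {
    "type": "text",
    "text": "One positive story today was ...",
    "annotations": [
        {
            "type": "web_search",  # hypothetical annotation instance
            "url": "https://example.com/story",
            "title": "Example headline",
            "start_index": 0,
            "end_index": 24,
        }
    ],
}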

View File

@@ -41,8 +41,8 @@ from typing_extensions import TypedDict
from langchain_openai import ChatOpenAI
from langchain_openai.chat_models.base import (
_FUNCTION_CALL_IDS_MAP_KEY,
- _construct_lc_result_from_response_api,
- _construct_response_api_input,
+ _construct_lc_result_from_responses_api,
+ _construct_responses_api_input,
_convert_dict_to_message,
_convert_message_to_dict,
_convert_to_openai_response_format,
@@ -955,7 +955,7 @@ def test_structured_outputs_parser() -> None:
assert result == parsed_response
- def test__construct_lc_result_from_response_api_error_handling() -> None:
+ def test__construct_lc_result_from_responses_api_error_handling() -> None:
"""Test that errors in the response are properly raised."""
response = Response(
id="resp_123",
@@ -970,12 +970,12 @@ def test__construct_lc_result_from_response_api_error_handling() -> None:
)
with pytest.raises(ValueError) as excinfo:
- _construct_lc_result_from_response_api(response)
+ _construct_lc_result_from_responses_api(response)
assert "Test error" in str(excinfo.value)
- def test__construct_lc_result_from_response_api_basic_text_response() -> None:
+ def test__construct_lc_result_from_responses_api_basic_text_response() -> None:
"""Test a basic text response with no tools or special features."""
response = Response(
id="resp_123",
@@ -1006,7 +1006,7 @@ def test__construct_lc_result_from_response_api_basic_text_response() -> None:
),
)
- result = _construct_lc_result_from_response_api(response)
+ result = _construct_lc_result_from_responses_api(response)
assert isinstance(result, ChatResult)
assert len(result.generations) == 1
@@ -1024,7 +1024,7 @@ def test__construct_lc_result_from_response_api_basic_text_response() -> None:
assert result.generations[0].message.response_metadata["model_name"] == "gpt-4o"
- def test__construct_lc_result_from_response_api_multiple_text_blocks() -> None:
+ def test__construct_lc_result_from_responses_api_multiple_text_blocks() -> None:
"""Test a response with multiple text blocks."""
response = Response(
id="resp_123",
@@ -1052,14 +1052,14 @@ def test__construct_lc_result_from_response_api_multiple_text_blocks() -> None:
],
)
- result = _construct_lc_result_from_response_api(response)
+ result = _construct_lc_result_from_responses_api(response)
assert len(result.generations[0].message.content) == 2
assert result.generations[0].message.content[0]["text"] == "First part" # type: ignore
assert result.generations[0].message.content[1]["text"] == "Second part" # type: ignore
- def test__construct_lc_result_from_response_api_refusal_response() -> None:
+ def test__construct_lc_result_from_responses_api_refusal_response() -> None:
"""Test a response with a refusal."""
response = Response(
id="resp_123",
@@ -1084,7 +1084,7 @@ def test__construct_lc_result_from_response_api_refusal_response() -> None:
],
)
- result = _construct_lc_result_from_response_api(response)
+ result = _construct_lc_result_from_responses_api(response)
assert result.generations[0].message.content == []
assert (
@@ -1093,7 +1093,7 @@ def test__construct_lc_result_from_response_api_refusal_response() -> None:
)
- def test__construct_lc_result_from_response_api_function_call_valid_json() -> None:
+ def test__construct_lc_result_from_responses_api_function_call_valid_json() -> None:
"""Test a response with a valid function call."""
response = Response(
id="resp_123",
@@ -1114,7 +1114,7 @@ def test__construct_lc_result_from_response_api_function_call_valid_json() -> No
],
)
- result = _construct_lc_result_from_response_api(response)
+ result = _construct_lc_result_from_responses_api(response)
msg: AIMessage = cast(AIMessage, result.generations[0].message)
assert len(msg.tool_calls) == 1
@@ -1131,7 +1131,7 @@ def test__construct_lc_result_from_response_api_function_call_valid_json() -> No
)
- def test__construct_lc_result_from_response_api_function_call_invalid_json() -> None:
+ def test__construct_lc_result_from_responses_api_function_call_invalid_json() -> None:
"""Test a response with an invalid JSON function call."""
response = Response(
id="resp_123",
@@ -1153,7 +1153,7 @@ def test__construct_lc_result_from_response_api_function_call_invalid_json() ->
],
)
- result = _construct_lc_result_from_response_api(response)
+ result = _construct_lc_result_from_responses_api(response)
msg: AIMessage = cast(AIMessage, result.generations[0].message)
assert len(msg.invalid_tool_calls) == 1
@@ -1168,7 +1168,7 @@ def test__construct_lc_result_from_response_api_function_call_invalid_json() ->
assert _FUNCTION_CALL_IDS_MAP_KEY in result.generations[0].message.additional_kwargs
- def test__construct_lc_result_from_response_api_complex_response() -> None:
+ def test__construct_lc_result_from_responses_api_complex_response() -> None:
"""Test a complex response with multiple output types."""
response = Response(
id="resp_123",
@@ -1206,7 +1206,7 @@ def test__construct_lc_result_from_response_api_complex_response() -> None:
user="user_123",
)
- result = _construct_lc_result_from_response_api(response)
+ result = _construct_lc_result_from_responses_api(response)
# Check message content
assert result.generations[0].message.content == [
@@ -1235,7 +1235,7 @@ def test__construct_lc_result_from_response_api_complex_response() -> None:
assert result.generations[0].message.response_metadata["user"] == "user_123"
- def test__construct_lc_result_from_response_api_no_usage_metadata() -> None:
+ def test__construct_lc_result_from_responses_api_no_usage_metadata() -> None:
"""Test a response without usage metadata."""
response = Response(
id="resp_123",
@@ -1261,12 +1261,12 @@ def test__construct_lc_result_from_response_api_no_usage_metadata() -> None:
# No usage field
)
- result = _construct_lc_result_from_response_api(response)
+ result = _construct_lc_result_from_responses_api(response)
assert cast(AIMessage, result.generations[0].message).usage_metadata is None
- def test__construct_lc_result_from_response_api_web_search_response() -> None:
+ def test__construct_lc_result_from_responses_api_web_search_response() -> None:
"""Test a response with web search output."""
from openai.types.responses.response_function_web_search import (
ResponseFunctionWebSearch,
@@ -1287,7 +1287,7 @@ def test__construct_lc_result_from_response_api_web_search_response() -> None:
],
)
- result = _construct_lc_result_from_response_api(response)
+ result = _construct_lc_result_from_responses_api(response)
assert "tool_outputs" in result.generations[0].message.additional_kwargs
assert len(result.generations[0].message.additional_kwargs["tool_outputs"]) == 1
@@ -1305,7 +1305,7 @@ def test__construct_lc_result_from_response_api_web_search_response() -> None:
)
- def test__construct_lc_result_from_response_api_file_search_response() -> None:
+ def test__construct_lc_result_from_responses_api_file_search_response() -> None:
"""Test a response with file search output."""
response = Response(
id="resp_123",
@@ -1334,7 +1334,7 @@ def test__construct_lc_result_from_response_api_file_search_response() -> None:
],
)
- result = _construct_lc_result_from_response_api(response)
+ result = _construct_lc_result_from_responses_api(response)
assert "tool_outputs" in result.generations[0].message.additional_kwargs
assert len(result.generations[0].message.additional_kwargs["tool_outputs"]) == 1
@@ -1375,7 +1375,7 @@ def test__construct_lc_result_from_response_api_file_search_response() -> None:
)
- def test__construct_lc_result_from_response_api_mixed_search_responses() -> None:
+ def test__construct_lc_result_from_responses_api_mixed_search_responses() -> None:
"""Test a response with both web search and file search outputs."""
response = Response(
@@ -1418,7 +1418,7 @@ def test__construct_lc_result_from_response_api_mixed_search_responses() -> None
],
)
- result = _construct_lc_result_from_response_api(response)
+ result = _construct_lc_result_from_responses_api(response)
# Check message content
assert result.generations[0].message.content == [
@@ -1449,14 +1449,14 @@ def test__construct_lc_result_from_response_api_mixed_search_responses() -> None
assert file_search["results"][0]["filename"] == "example.py"
- def test__construct_response_api_input_human_message_with_text_blocks_conversion() -> (
+ def test__construct_responses_api_input_human_message_with_text_blocks_conversion() -> (
None
):
"""Test that human messages with text blocks are properly converted."""
messages: list = [
HumanMessage(content=[{"type": "text", "text": "What's in this image?"}])
]
- result = _construct_response_api_input(messages)
+ result = _construct_responses_api_input(messages)
assert len(result) == 1
assert result[0]["role"] == "user"
@@ -1466,7 +1466,7 @@ def test__construct_response_api_input_human_message_with_text_blocks_conversion
assert result[0]["content"][0]["text"] == "What's in this image?"
- def test__construct_response_api_input_human_message_with_image_url_conversion() -> (
+ def test__construct_responses_api_input_human_message_with_image_url_conversion() -> (
None
):
"""Test that human messages with image_url blocks are properly converted."""
@@ -1484,7 +1484,7 @@ def test__construct_response_api_input_human_message_with_image_url_conversion()
]
)
]
- result = _construct_response_api_input(messages)
+ result = _construct_responses_api_input(messages)
assert len(result) == 1
assert result[0]["role"] == "user"
@@ -1501,7 +1501,7 @@ def test__construct_response_api_input_human_message_with_image_url_conversion()
assert result[0]["content"][1]["detail"] == "high"
- def test__construct_response_api_input_ai_message_with_tool_calls() -> None:
+ def test__construct_responses_api_input_ai_message_with_tool_calls() -> None:
"""Test that AI messages with tool calls are properly converted."""
tool_calls = [
{
@@ -1521,7 +1521,7 @@ def test__construct_response_api_input_ai_message_with_tool_calls() -> None:
additional_kwargs={_FUNCTION_CALL_IDS_MAP_KEY: function_call_ids},
)
- result = _construct_response_api_input([ai_message])
+ result = _construct_responses_api_input([ai_message])
assert len(result) == 1
assert result[0]["type"] == "function_call"
@@ -1531,7 +1531,9 @@ def test__construct_response_api_input_ai_message_with_tool_calls() -> None:
assert result[0]["id"] == "func_456"
- def test__construct_response_api_input_ai_message_with_tool_calls_and_content() -> None:
+ def test__construct_responses_api_input_ai_message_with_tool_calls_and_content() -> (
+ None
+ ):
"""Test that AI messages with both tool calls and content are properly converted."""
tool_calls = [
{
@@ -1551,7 +1553,7 @@ def test__construct_response_api_input_ai_message_with_tool_calls_and_content()
additional_kwargs={_FUNCTION_CALL_IDS_MAP_KEY: function_call_ids},
)
- result = _construct_response_api_input([ai_message])
+ result = _construct_responses_api_input([ai_message])
assert len(result) == 2
@@ -1567,7 +1569,7 @@ def test__construct_response_api_input_ai_message_with_tool_calls_and_content()
assert result[1]["id"] == "func_456"
- def test__construct_response_api_input_missing_function_call_ids() -> None:
+ def test__construct_responses_api_input_missing_function_call_ids() -> None:
"""Test AI messages with tool calls but missing function call IDs raise an error."""
tool_calls = [
{
@@ -1581,10 +1583,10 @@ def test__construct_response_api_input_missing_function_call_ids() -> None:
ai_message = AIMessage(content="", tool_calls=tool_calls)
with pytest.raises(ValueError):
- _construct_response_api_input([ai_message])
+ _construct_responses_api_input([ai_message])
- def test__construct_response_api_input_tool_message_conversion() -> None:
+ def test__construct_responses_api_input_tool_message_conversion() -> None:
"""Test that tool messages are properly converted to function_call_output."""
messages = [
ToolMessage(
@@ -1593,7 +1595,7 @@ def test__construct_response_api_input_tool_message_conversion() -> None:
)
]
- result = _construct_response_api_input(messages)
+ result = _construct_responses_api_input(messages)
assert len(result) == 1
assert result[0]["type"] == "function_call_output"
@@ -1601,7 +1603,7 @@ def test__construct_response_api_input_tool_message_conversion() -> None:
assert result[0]["call_id"] == "call_123"
- def test__construct_response_api_input_multiple_message_types() -> None:
+ def test__construct_responses_api_input_multiple_message_types() -> None:
"""Test conversion of a conversation with multiple message types."""
messages = [
SystemMessage(content="You are a helpful assistant."),
@@ -1637,7 +1639,7 @@ def test__construct_response_api_input_multiple_message_types() -> None:
]
messages_copy = [m.copy(deep=True) for m in messages]
- result = _construct_response_api_input(messages)
+ result = _construct_responses_api_input(messages)
assert len(result) == len(messages)
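
Taken together, the _construct_responses_api_input tests pin down the item shapes for a tool-call round trip. A hedged sketch of those shapes follows; field names reflect OpenAI's Responses function-call items as exercised by the tests, and every value below is invented:

# Illustrative shapes only; ids and values are made up, keys mirror the tests.
function_call_item = {
    "type": "function_call",
    "name": "get_weather",              # tool name from the AIMessage tool call
    "arguments": '{"location": "SF"}',  # JSON-encoded arguments string
    "call_id": "call_123",              # LangChain tool_call id
    "id": "func_456",                   # recovered via _FUNCTION_CALL_IDS_MAP_KEY
}
function_call_output_item = {
    "type": "function_call_output",
    "call_id": "call_123",              # matches the originating call
    "output": "sunny",                  # ToolMessage content
}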