mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-25 04:49:17 +00:00
openai[patch]: support Responses API (#30231)
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
import json
|
||||
from functools import partial
|
||||
from types import TracebackType
|
||||
from typing import Any, Dict, List, Literal, Optional, Type, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Type, Union, cast
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -19,13 +19,30 @@ from langchain_core.messages import (
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.messages.ai import UsageMetadata
|
||||
from langchain_core.outputs import ChatGeneration
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
from openai.types.responses import ResponseOutputMessage
|
||||
from openai.types.responses.response import IncompleteDetails, Response, ResponseUsage
|
||||
from openai.types.responses.response_error import ResponseError
|
||||
from openai.types.responses.response_file_search_tool_call import (
|
||||
ResponseFileSearchToolCall,
|
||||
Result,
|
||||
)
|
||||
from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall
|
||||
from openai.types.responses.response_function_web_search import (
|
||||
ResponseFunctionWebSearch,
|
||||
)
|
||||
from openai.types.responses.response_output_refusal import ResponseOutputRefusal
|
||||
from openai.types.responses.response_output_text import ResponseOutputText
|
||||
from openai.types.responses.response_usage import OutputTokensDetails
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_openai.chat_models.base import (
|
||||
_FUNCTION_CALL_IDS_MAP_KEY,
|
||||
_construct_lc_result_from_responses_api,
|
||||
_construct_responses_api_input,
|
||||
_convert_dict_to_message,
|
||||
_convert_message_to_dict,
|
||||
_convert_to_openai_response_format,
|
||||
@@ -862,7 +879,7 @@ def test_nested_structured_output_strict() -> None:
|
||||
|
||||
setup: str
|
||||
punchline: str
|
||||
self_evaluation: SelfEvaluation
|
||||
_evaluation: SelfEvaluation
|
||||
|
||||
llm.with_structured_output(JokeWithEvaluation, method="json_schema")
|
||||
|
||||
@@ -936,3 +953,731 @@ def test_structured_outputs_parser() -> None:
|
||||
assert isinstance(deserialized, ChatGeneration)
|
||||
result = output_parser.invoke(deserialized.message)
|
||||
assert result == parsed_response
|
||||
|
||||
|
||||
def test__construct_lc_result_from_responses_api_error_handling() -> None:
    """An error reported by the Responses API must surface as a ValueError."""
    response = Response(
        id="resp_123",
        created_at=1234567890,
        model="gpt-4o",
        object="response",
        error=ResponseError(message="Test error", code="server_error"),
        parallel_tool_calls=True,
        tools=[],
        tool_choice="auto",
        output=[],
    )

    # The raised message should carry the API-reported error text through.
    with pytest.raises(ValueError, match="Test error"):
        _construct_lc_result_from_responses_api(response)


def test__construct_lc_result_from_responses_api_basic_text_response() -> None:
    """A plain text response maps to a one-generation ChatResult with usage."""
    usage = ResponseUsage(
        input_tokens=10,
        output_tokens=3,
        total_tokens=13,
        output_tokens_details=OutputTokensDetails(reasoning_tokens=0),
    )
    response = Response(
        id="resp_123",
        created_at=1234567890,
        model="gpt-4o",
        object="response",
        parallel_tool_calls=True,
        tools=[],
        tool_choice="auto",
        output=[
            ResponseOutputMessage(
                type="message",
                id="msg_123",
                content=[
                    ResponseOutputText(
                        type="output_text", text="Hello, world!", annotations=[]
                    )
                ],
                role="assistant",
                status="completed",
            )
        ],
        usage=usage,
    )

    result = _construct_lc_result_from_responses_api(response)

    assert isinstance(result, ChatResult)
    assert len(result.generations) == 1
    generation = result.generations[0]
    assert isinstance(generation, ChatGeneration)
    message = generation.message
    assert isinstance(message, AIMessage)
    assert message.content == [
        {"type": "text", "text": "Hello, world!", "annotations": []}
    ]
    assert message.id == "msg_123"
    # Token accounting is copied from the ResponseUsage block.
    assert message.usage_metadata
    assert message.usage_metadata["input_tokens"] == 10
    assert message.usage_metadata["output_tokens"] == 3
    assert message.usage_metadata["total_tokens"] == 13
    # Response-level identifiers land in response_metadata.
    assert message.response_metadata["id"] == "resp_123"
    assert message.response_metadata["model_name"] == "gpt-4o"


def test__construct_lc_result_from_responses_api_multiple_text_blocks() -> None:
    """Each output text block becomes its own entry in message content."""
    response = Response(
        id="resp_123",
        created_at=1234567890,
        model="gpt-4o",
        object="response",
        parallel_tool_calls=True,
        tools=[],
        tool_choice="auto",
        output=[
            ResponseOutputMessage(
                type="message",
                id="msg_123",
                content=[
                    ResponseOutputText(
                        type="output_text", text="First part", annotations=[]
                    ),
                    ResponseOutputText(
                        type="output_text", text="Second part", annotations=[]
                    ),
                ],
                role="assistant",
                status="completed",
            )
        ],
    )

    result = _construct_lc_result_from_responses_api(response)
    content = result.generations[0].message.content

    assert len(content) == 2
    assert content[0]["text"] == "First part"  # type: ignore
    assert content[1]["text"] == "Second part"  # type: ignore


def test__construct_lc_result_from_responses_api_refusal_response() -> None:
    """A refusal is exposed via additional_kwargs, not message content."""
    response = Response(
        id="resp_123",
        created_at=1234567890,
        model="gpt-4o",
        object="response",
        parallel_tool_calls=True,
        tools=[],
        tool_choice="auto",
        output=[
            ResponseOutputMessage(
                type="message",
                id="msg_123",
                content=[
                    ResponseOutputRefusal(
                        type="refusal", refusal="I cannot assist with that request."
                    )
                ],
                role="assistant",
                status="completed",
            )
        ],
    )

    result = _construct_lc_result_from_responses_api(response)
    message = result.generations[0].message

    # No regular content is produced for a refusal-only message.
    assert message.content == []
    assert message.additional_kwargs["refusal"] == "I cannot assist with that request."


def test__construct_lc_result_from_responses_api_function_call_valid_json() -> None:
    """A function call with well-formed JSON arguments becomes a tool call."""
    response = Response(
        id="resp_123",
        created_at=1234567890,
        model="gpt-4o",
        object="response",
        parallel_tool_calls=True,
        tools=[],
        tool_choice="auto",
        output=[
            ResponseFunctionToolCall(
                type="function_call",
                id="func_123",
                call_id="call_123",
                name="get_weather",
                arguments='{"location": "New York", "unit": "celsius"}',
            )
        ],
    )

    result = _construct_lc_result_from_responses_api(response)
    message = cast(AIMessage, result.generations[0].message)

    assert len(message.tool_calls) == 1
    tool_call = message.tool_calls[0]
    assert tool_call["type"] == "tool_call"
    assert tool_call["name"] == "get_weather"
    assert tool_call["id"] == "call_123"
    assert tool_call["args"] == {"location": "New York", "unit": "celsius"}
    # The Responses API item id is remembered so it can be echoed on input.
    assert _FUNCTION_CALL_IDS_MAP_KEY in message.additional_kwargs
    assert (
        message.additional_kwargs[_FUNCTION_CALL_IDS_MAP_KEY]["call_123"] == "func_123"
    )


def test__construct_lc_result_from_responses_api_function_call_invalid_json() -> None:
    """Malformed JSON arguments produce an invalid_tool_call entry."""
    # Arguments string is deliberately missing its closing brace.
    bad_arguments = '{"location": "New York", "unit": "celsius"'
    response = Response(
        id="resp_123",
        created_at=1234567890,
        model="gpt-4o",
        object="response",
        parallel_tool_calls=True,
        tools=[],
        tool_choice="auto",
        output=[
            ResponseFunctionToolCall(
                type="function_call",
                id="func_123",
                call_id="call_123",
                name="get_weather",
                arguments=bad_arguments,
            )
        ],
    )

    result = _construct_lc_result_from_responses_api(response)
    message = cast(AIMessage, result.generations[0].message)

    assert len(message.invalid_tool_calls) == 1
    invalid_call = message.invalid_tool_calls[0]
    assert invalid_call["type"] == "invalid_tool_call"
    assert invalid_call["name"] == "get_weather"
    assert invalid_call["id"] == "call_123"
    # Raw (unparsed) arguments are preserved verbatim alongside the error.
    assert invalid_call["args"] == bad_arguments
    assert "error" in invalid_call
    assert _FUNCTION_CALL_IDS_MAP_KEY in message.additional_kwargs


def test__construct_lc_result_from_responses_api_complex_response() -> None:
    """Mixed text + function-call output fills content, tool calls, metadata."""
    response = Response(
        id="resp_123",
        created_at=1234567890,
        model="gpt-4o",
        object="response",
        parallel_tool_calls=True,
        tools=[],
        tool_choice="auto",
        output=[
            ResponseOutputMessage(
                type="message",
                id="msg_123",
                content=[
                    ResponseOutputText(
                        type="output_text",
                        text="Here's the information you requested:",
                        annotations=[],
                    )
                ],
                role="assistant",
                status="completed",
            ),
            ResponseFunctionToolCall(
                type="function_call",
                id="func_123",
                call_id="call_123",
                name="get_weather",
                arguments='{"location": "New York"}',
            ),
        ],
        metadata={"key1": "value1", "key2": "value2"},
        incomplete_details=IncompleteDetails(reason="max_output_tokens"),
        status="completed",
        user="user_123",
    )

    result = _construct_lc_result_from_responses_api(response)
    message = cast(AIMessage, result.generations[0].message)

    # Text output becomes message content.
    assert message.content == [
        {
            "type": "text",
            "text": "Here's the information you requested:",
            "annotations": [],
        }
    ]

    # Function-call output becomes a tool call.
    assert len(message.tool_calls) == 1
    assert message.tool_calls[0]["name"] == "get_weather"

    # Response-level fields land in response_metadata.
    metadata = message.response_metadata
    assert metadata["id"] == "resp_123"
    assert metadata["metadata"] == {
        "key1": "value1",
        "key2": "value2",
    }
    assert metadata["incomplete_details"] == {"reason": "max_output_tokens"}
    assert metadata["status"] == "completed"
    assert metadata["user"] == "user_123"


def test__construct_lc_result_from_responses_api_no_usage_metadata() -> None:
    """Without a usage block on the response, usage_metadata is None."""
    response = Response(
        id="resp_123",
        created_at=1234567890,
        model="gpt-4o",
        object="response",
        parallel_tool_calls=True,
        tools=[],
        tool_choice="auto",
        output=[
            ResponseOutputMessage(
                type="message",
                id="msg_123",
                content=[
                    ResponseOutputText(
                        type="output_text", text="Hello, world!", annotations=[]
                    )
                ],
                role="assistant",
                status="completed",
            )
        ],
        # usage deliberately omitted
    )

    result = _construct_lc_result_from_responses_api(response)
    message = cast(AIMessage, result.generations[0].message)

    assert message.usage_metadata is None


def test__construct_lc_result_from_responses_api_web_search_response() -> None:
    """Test a response with web search output.

    Web search calls are surfaced as structured entries under the
    ``tool_outputs`` key of ``additional_kwargs``.
    """
    # NOTE: ResponseFunctionWebSearch is already imported at module level;
    # the previous function-local re-import was redundant and is removed.
    response = Response(
        id="resp_123",
        created_at=1234567890,
        model="gpt-4o",
        object="response",
        parallel_tool_calls=True,
        tools=[],
        tool_choice="auto",
        output=[
            ResponseFunctionWebSearch(
                id="websearch_123", type="web_search_call", status="completed"
            )
        ],
    )

    result = _construct_lc_result_from_responses_api(response)
    message = result.generations[0].message

    assert "tool_outputs" in message.additional_kwargs
    tool_outputs = message.additional_kwargs["tool_outputs"]
    assert len(tool_outputs) == 1
    web_search = tool_outputs[0]
    assert web_search["type"] == "web_search_call"
    assert web_search["id"] == "websearch_123"
    assert web_search["status"] == "completed"


def test__construct_lc_result_from_responses_api_file_search_response() -> None:
    """File-search tool calls are exposed through tool_outputs."""
    response = Response(
        id="resp_123",
        created_at=1234567890,
        model="gpt-4o",
        object="response",
        parallel_tool_calls=True,
        tools=[],
        tool_choice="auto",
        output=[
            ResponseFileSearchToolCall(
                id="filesearch_123",
                type="file_search_call",
                status="completed",
                queries=["python code", "langchain"],
                results=[
                    Result(
                        file_id="file_123",
                        filename="example.py",
                        score=0.95,
                        text="def hello_world() -> None:\n print('Hello, world!')",
                        attributes={"language": "python", "size": 42},
                    )
                ],
            )
        ],
    )

    result = _construct_lc_result_from_responses_api(response)
    message = result.generations[0].message

    assert "tool_outputs" in message.additional_kwargs
    tool_outputs = message.additional_kwargs["tool_outputs"]
    assert len(tool_outputs) == 1
    file_search = tool_outputs[0]
    assert file_search["type"] == "file_search_call"
    assert file_search["id"] == "filesearch_123"
    assert file_search["status"] == "completed"
    assert file_search["queries"] == ["python code", "langchain"]
    # Search results round-trip with their scores and file identifiers.
    assert len(file_search["results"]) == 1
    assert file_search["results"][0]["file_id"] == "file_123"
    assert file_search["results"][0]["score"] == 0.95


def test__construct_lc_result_from_responses_api_mixed_search_responses() -> None:
    """Web-search and file-search outputs coexist in tool_outputs."""
    response = Response(
        id="resp_123",
        created_at=1234567890,
        model="gpt-4o",
        object="response",
        parallel_tool_calls=True,
        tools=[],
        tool_choice="auto",
        output=[
            ResponseOutputMessage(
                type="message",
                id="msg_123",
                content=[
                    ResponseOutputText(
                        type="output_text", text="Here's what I found:", annotations=[]
                    )
                ],
                role="assistant",
                status="completed",
            ),
            ResponseFunctionWebSearch(
                id="websearch_123", type="web_search_call", status="completed"
            ),
            ResponseFileSearchToolCall(
                id="filesearch_123",
                type="file_search_call",
                status="completed",
                queries=["python code"],
                results=[
                    Result(
                        file_id="file_123",
                        filename="example.py",
                        score=0.95,
                        text="def hello_world() -> None:\n print('Hello, world!')",
                    )
                ],
            ),
        ],
    )

    result = _construct_lc_result_from_responses_api(response)
    message = result.generations[0].message

    # Text output becomes message content.
    assert message.content == [
        {"type": "text", "text": "Here's what I found:", "annotations": []}
    ]

    # Both search calls appear as tool outputs.
    assert "tool_outputs" in message.additional_kwargs
    tool_outputs = message.additional_kwargs["tool_outputs"]
    assert len(tool_outputs) == 2
    outputs_by_type = {output["type"]: output for output in tool_outputs}

    web_search = outputs_by_type["web_search_call"]
    assert web_search["id"] == "websearch_123"
    assert web_search["status"] == "completed"

    file_search = outputs_by_type["file_search_call"]
    assert file_search["id"] == "filesearch_123"
    assert file_search["queries"] == ["python code"]
    assert file_search["results"][0]["filename"] == "example.py"


def test__construct_responses_api_input_human_message_with_text_blocks_conversion() -> (
    None
):
    """Human "text" content blocks convert to "input_text" blocks."""
    messages: list = [
        HumanMessage(content=[{"type": "text", "text": "What's in this image?"}])
    ]

    result = _construct_responses_api_input(messages)

    assert len(result) == 1
    item = result[0]
    assert item["role"] == "user"
    content = item["content"]
    assert isinstance(content, list)
    assert len(content) == 1
    assert content[0]["type"] == "input_text"
    assert content[0]["text"] == "What's in this image?"


def test__construct_responses_api_input_human_message_with_image_url_conversion() -> (
    None
):
    """Human "image_url" blocks convert to flattened "input_image" blocks."""
    messages: list = [
        HumanMessage(
            content=[
                {"type": "text", "text": "What's in this image?"},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://example.com/image.jpg",
                        "detail": "high",
                    },
                },
            ]
        )
    ]

    result = _construct_responses_api_input(messages)

    assert len(result) == 1
    item = result[0]
    assert item["role"] == "user"
    content = item["content"]
    assert isinstance(content, list)
    assert len(content) == 2

    # Text block: "text" -> "input_text".
    assert content[0]["type"] == "input_text"
    assert content[0]["text"] == "What's in this image?"

    # Image block: the nested image_url dict is flattened to top-level keys.
    assert content[1]["type"] == "input_image"
    assert content[1]["image_url"] == "https://example.com/image.jpg"
    assert content[1]["detail"] == "high"


def test__construct_responses_api_input_ai_message_with_tool_calls() -> None:
    """AI tool calls convert to Responses API function_call items."""
    tool_calls = [
        {
            "id": "call_123",
            "name": "get_weather",
            "args": {"location": "San Francisco"},
            "type": "tool_call",
        }
    ]

    ai_message = AIMessage(
        content="",
        tool_calls=tool_calls,
        # Maps tool-call ids back to the Responses API function-call item ids.
        additional_kwargs={_FUNCTION_CALL_IDS_MAP_KEY: {"call_123": "func_456"}},
    )

    result = _construct_responses_api_input([ai_message])

    assert len(result) == 1
    function_call = result[0]
    assert function_call["type"] == "function_call"
    assert function_call["name"] == "get_weather"
    assert function_call["arguments"] == '{"location": "San Francisco"}'
    assert function_call["call_id"] == "call_123"
    assert function_call["id"] == "func_456"


def test__construct_responses_api_input_ai_message_with_tool_calls_and_content() -> (
    None
):
    """An AI message carrying both text and tool calls yields two input items."""
    tool_calls = [
        {
            "id": "call_123",
            "name": "get_weather",
            "args": {"location": "San Francisco"},
            "type": "tool_call",
        }
    ]

    ai_message = AIMessage(
        content="I'll check the weather for you.",
        tool_calls=tool_calls,
        # Maps tool-call ids back to the Responses API function-call item ids.
        additional_kwargs={_FUNCTION_CALL_IDS_MAP_KEY: {"call_123": "func_456"}},
    )

    result = _construct_responses_api_input([ai_message])

    assert len(result) == 2

    # First item carries the assistant text.
    assert result[0]["role"] == "assistant"
    assert result[0]["content"] == "I'll check the weather for you."

    # Second item carries the function call.
    function_call = result[1]
    assert function_call["type"] == "function_call"
    assert function_call["name"] == "get_weather"
    assert function_call["arguments"] == '{"location": "San Francisco"}'
    assert function_call["call_id"] == "call_123"
    assert function_call["id"] == "func_456"


def test__construct_responses_api_input_missing_function_call_ids() -> None:
    """Tool calls without a recorded function-call id mapping must raise."""
    ai_message = AIMessage(
        content="",
        # No _FUNCTION_CALL_IDS_MAP_KEY entry in additional_kwargs.
        tool_calls=[
            {
                "id": "call_123",
                "name": "get_weather",
                "args": {"location": "San Francisco"},
                "type": "tool_call",
            }
        ],
    )

    with pytest.raises(ValueError):
        _construct_responses_api_input([ai_message])


def test__construct_responses_api_input_tool_message_conversion() -> None:
    """Tool messages convert to function_call_output items."""
    tool_message = ToolMessage(
        content='{"temperature": 72, "conditions": "sunny"}',
        tool_call_id="call_123",
    )

    result = _construct_responses_api_input([tool_message])

    assert len(result) == 1
    output_item = result[0]
    assert output_item["type"] == "function_call_output"
    assert output_item["output"] == '{"temperature": 72, "conditions": "sunny"}'
    assert output_item["call_id"] == "call_123"


def test__construct_responses_api_input_multiple_message_types() -> None:
    """Test conversion of a conversation with multiple message types.

    Covers system/human/AI/tool messages with both string and block content,
    and verifies the input messages are not mutated by the conversion.
    """
    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="What's the weather in San Francisco?"),
        HumanMessage(
            content=[{"type": "text", "text": "What's the weather in San Francisco?"}]
        ),
        AIMessage(
            content="",
            tool_calls=[
                {
                    "type": "tool_call",
                    "id": "call_123",
                    "name": "get_weather",
                    "args": {"location": "San Francisco"},
                }
            ],
            additional_kwargs={_FUNCTION_CALL_IDS_MAP_KEY: {"call_123": "func_456"}},
        ),
        ToolMessage(
            content='{"temperature": 72, "conditions": "sunny"}',
            tool_call_id="call_123",
        ),
        AIMessage(content="The weather in San Francisco is 72°F and sunny."),
        AIMessage(
            content=[
                {
                    "type": "text",
                    "text": "The weather in San Francisco is 72°F and sunny.",
                }
            ]
        ),
    ]
    # Deep-copy for the no-mutation check below. ``model_copy`` replaces the
    # pydantic-v1-style ``copy(deep=True)``, which is deprecated in pydantic v2.
    messages_copy = [m.model_copy(deep=True) for m in messages]

    result = _construct_responses_api_input(messages)

    assert len(result) == len(messages)

    # Check system message
    assert result[0]["role"] == "system"
    assert result[0]["content"] == "You are a helpful assistant."

    # Check human messages (string and block content)
    assert result[1]["role"] == "user"
    assert result[1]["content"] == "What's the weather in San Francisco?"
    assert result[2]["role"] == "user"
    assert result[2]["content"] == [
        {"type": "input_text", "text": "What's the weather in San Francisco?"}
    ]

    # Check function call
    assert result[3]["type"] == "function_call"
    assert result[3]["name"] == "get_weather"
    assert result[3]["arguments"] == '{"location": "San Francisco"}'
    assert result[3]["call_id"] == "call_123"
    assert result[3]["id"] == "func_456"

    # Check function call output
    assert result[4]["type"] == "function_call_output"
    assert result[4]["output"] == '{"temperature": 72, "conditions": "sunny"}'
    assert result[4]["call_id"] == "call_123"

    # Check AI messages (string and block content)
    assert result[5]["role"] == "assistant"
    assert result[5]["content"] == "The weather in San Francisco is 72°F and sunny."
    assert result[6]["role"] == "assistant"
    assert result[6]["content"] == [
        {
            "type": "output_text",
            "text": "The weather in San Francisco is 72°F and sunny.",
            "annotations": [],
        }
    ]

    # assert no mutation has occurred
    assert messages_copy == messages
|
||||
|
Reference in New Issue
Block a user