unit tests

This commit is contained in:
Bagatur 2025-03-11 21:06:24 -07:00
parent 1d10d0d66f
commit 2c7bd81fef
2 changed files with 211 additions and 5 deletions

View File

@ -2760,7 +2760,7 @@ def _use_response_api(payload: dict) -> bool:
)
def _construct_response_api_input(messages: list[BaseMessage]) -> list:
def _construct_response_api_input(messages: Sequence[BaseMessage]) -> list:
input_ = []
for lc_msg in messages:
msg = _convert_message_to_dict(lc_msg)
@ -2772,6 +2772,8 @@ def _construct_response_api_input(messages: list[BaseMessage]) -> list:
}
input_.append(function_call_output)
elif msg["role"] == "assistant":
if msg.get("content"):
input_.append(msg)
if tool_calls := msg.pop("tool_calls", None):
if not lc_msg.additional_kwargs.get(_FUNCTION_CALL_IDS_MAP_KEY):
raise ValueError(...)
@ -2779,14 +2781,12 @@ def _construct_response_api_input(messages: list[BaseMessage]) -> list:
for tool_call in tool_calls:
function_call = {
"type": "function_call",
"name": tool_call["name"],
"arguments": tool_call["arguments"],
"name": tool_call["function"]["name"],
"arguments": tool_call["function"]["arguments"],
"call_id": tool_call["id"],
"id": function_call_ids[tool_call["id"]],
}
input_.append(function_call)
if msg.get("content"):
input_.append(msg)
elif msg["role"] == "user":
if isinstance(msg["content"], list):
for block in msg["content"]:

View File

@ -42,6 +42,7 @@ from langchain_openai import ChatOpenAI
from langchain_openai.chat_models.base import (
_FUNCTION_CALL_IDS_MAP_KEY,
_construct_lc_result_from_response_api,
_construct_response_api_input,
_convert_dict_to_message,
_convert_message_to_dict,
_convert_to_openai_response_format,
@ -1446,3 +1447,208 @@ def test__construct_lc_result_from_response_api_mixed_search_responses() -> None
assert file_search["id"] == "filesearch_123"
assert file_search["queries"] == ["python code"]
assert file_search["results"][0]["filename"] == "example.py"
def test__construct_response_api_input_human_message_with_text_blocks_conversion() -> (
None
):
"""Test that human messages with text blocks are properly converted."""
messages: list = [
HumanMessage(content=[{"type": "text", "text": "What's in this image?"}])
]
result = _construct_response_api_input(messages)
assert len(result) == 1
assert result[0]["role"] == "user"
assert isinstance(result[0]["content"], list)
assert len(result[0]["content"]) == 1
assert result[0]["content"][0]["type"] == "input_text"
assert result[0]["content"][0]["text"] == "What's in this image?"
def test__construct_response_api_input_human_message_with_image_url_conversion() -> (
None
):
"""Test that human messages with image_url blocks are properly converted."""
messages: list = [
HumanMessage(
content=[
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {
"url": "https://example.com/image.jpg",
"detail": "high",
},
},
]
)
]
result = _construct_response_api_input(messages)
assert len(result) == 1
assert result[0]["role"] == "user"
assert isinstance(result[0]["content"], list)
assert len(result[0]["content"]) == 2
# Check text block conversion
assert result[0]["content"][0]["type"] == "input_text"
assert result[0]["content"][0]["text"] == "What's in this image?"
# Check image block conversion
assert result[0]["content"][1]["type"] == "input_image"
assert result[0]["content"][1]["image_url"] == "https://example.com/image.jpg"
assert result[0]["content"][1]["detail"] == "high"
def test__construct_response_api_input_ai_message_with_tool_calls() -> None:
"""Test that AI messages with tool calls are properly converted."""
tool_calls = [
{
"id": "call_123",
"name": "get_weather",
"args": {"location": "San Francisco"},
"type": "tool_call",
}
]
# Create a mapping from tool call IDs to function call IDs
function_call_ids = {"call_123": "func_456"}
ai_message = AIMessage(
content="",
tool_calls=tool_calls,
additional_kwargs={_FUNCTION_CALL_IDS_MAP_KEY: function_call_ids},
)
result = _construct_response_api_input([ai_message])
assert len(result) == 1
assert result[0]["type"] == "function_call"
assert result[0]["name"] == "get_weather"
assert result[0]["arguments"] == '{"location": "San Francisco"}'
assert result[0]["call_id"] == "call_123"
assert result[0]["id"] == "func_456"
def test__construct_response_api_input_ai_message_with_tool_calls_and_content() -> None:
"""Test that AI messages with both tool calls and content are properly converted."""
tool_calls = [
{
"id": "call_123",
"name": "get_weather",
"args": {"location": "San Francisco"},
"type": "tool_call",
}
]
# Create a mapping from tool call IDs to function call IDs
function_call_ids = {"call_123": "func_456"}
ai_message = AIMessage(
content="I'll check the weather for you.",
tool_calls=tool_calls,
additional_kwargs={_FUNCTION_CALL_IDS_MAP_KEY: function_call_ids},
)
result = _construct_response_api_input([ai_message])
assert len(result) == 2
# Check content
assert result[0]["role"] == "assistant"
assert result[0]["content"] == "I'll check the weather for you."
# Check function call
assert result[1]["type"] == "function_call"
assert result[1]["name"] == "get_weather"
assert result[1]["arguments"] == '{"location": "San Francisco"}'
assert result[1]["call_id"] == "call_123"
assert result[1]["id"] == "func_456"
def test__construct_response_api_input_missing_function_call_ids() -> None:
"""Test AI messages with tool calls but missing function call IDs raise an error."""
tool_calls = [
{
"id": "call_123",
"name": "get_weather",
"args": {"location": "San Francisco"},
"type": "tool_call",
}
]
ai_message = AIMessage(content="", tool_calls=tool_calls)
with pytest.raises(ValueError):
_construct_response_api_input([ai_message])
def test__construct_response_api_input_tool_message_conversion() -> None:
"""Test that tool messages are properly converted to function_call_output."""
messages = [
ToolMessage(
content='{"temperature": 72, "conditions": "sunny"}',
tool_call_id="call_123",
)
]
result = _construct_response_api_input(messages)
assert len(result) == 1
assert result[0]["type"] == "function_call_output"
assert result[0]["output"] == '{"temperature": 72, "conditions": "sunny"}'
assert result[0]["call_id"] == "call_123"
def test__construct_response_api_input_multiple_message_types() -> None:
"""Test conversion of a conversation with multiple message types."""
messages = [
SystemMessage(content="You are a helpful assistant."),
HumanMessage(content="What's the weather in San Francisco?"),
AIMessage(
content="",
tool_calls=[
{
"type": "tool_call",
"id": "call_123",
"name": "get_weather",
"args": {"location": "San Francisco"},
}
],
additional_kwargs={_FUNCTION_CALL_IDS_MAP_KEY: {"call_123": "func_456"}},
),
ToolMessage(
content='{"temperature": 72, "conditions": "sunny"}',
tool_call_id="call_123",
),
AIMessage(content="The weather in San Francisco is 72°F and sunny."),
]
result = _construct_response_api_input(messages)
assert len(result) == 5
# Check system message
assert result[0]["role"] == "system"
assert result[0]["content"] == "You are a helpful assistant."
# Check human message
assert result[1]["role"] == "user"
assert result[1]["content"] == "What's the weather in San Francisco?"
# Check function call
assert result[2]["type"] == "function_call"
assert result[2]["name"] == "get_weather"
assert result[2]["arguments"] == '{"location": "San Francisco"}'
assert result[2]["call_id"] == "call_123"
assert result[2]["id"] == "func_456"
# Check function call output
assert result[3]["type"] == "function_call_output"
assert result[3]["output"] == '{"temperature": 72, "conditions": "sunny"}'
assert result[3]["call_id"] == "call_123"
# Check final AI message
assert result[4]["role"] == "assistant"
assert result[4]["content"] == "The weather in San Francisco is 72°F and sunny."