From 2c7bd81fef6577d462ac64e65116c64e00c2333e Mon Sep 17 00:00:00 2001 From: Bagatur Date: Tue, 11 Mar 2025 21:06:24 -0700 Subject: [PATCH] unit tests --- .../langchain_openai/chat_models/base.py | 10 +- .../tests/unit_tests/chat_models/test_base.py | 206 ++++++++++++++++++ 2 files changed, 211 insertions(+), 5 deletions(-) diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 5d6e1bd255d..55d65874b9f 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -2760,7 +2760,7 @@ def _use_response_api(payload: dict) -> bool: ) -def _construct_response_api_input(messages: list[BaseMessage]) -> list: +def _construct_response_api_input(messages: Sequence[BaseMessage]) -> list: input_ = [] for lc_msg in messages: msg = _convert_message_to_dict(lc_msg) @@ -2772,6 +2772,8 @@ def _construct_response_api_input(messages: list[BaseMessage]) -> list: } input_.append(function_call_output) elif msg["role"] == "assistant": + if msg.get("content"): + input_.append(msg) if tool_calls := msg.pop("tool_calls", None): if not lc_msg.additional_kwargs.get(_FUNCTION_CALL_IDS_MAP_KEY): raise ValueError(...) @@ -2779,14 +2781,12 @@ def _construct_response_api_input(messages: list[BaseMessage]) -> list: for tool_call in tool_calls: function_call = { "type": "function_call", - "name": tool_call["name"], - "arguments": tool_call["arguments"], + "name": tool_call["function"]["name"], + "arguments": tool_call["function"]["arguments"], "call_id": tool_call["id"], "id": function_call_ids[tool_call["id"]], } input_.append(function_call) - if msg.get("content"): - input_.append(msg) elif msg["role"] == "user": if isinstance(msg["content"], list): for block in msg["content"]: diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py index 5184df26b8d..938c0bc8115 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py @@ -42,6 +42,7 @@ from langchain_openai import ChatOpenAI from langchain_openai.chat_models.base import ( _FUNCTION_CALL_IDS_MAP_KEY, _construct_lc_result_from_response_api, + _construct_response_api_input, _convert_dict_to_message, _convert_message_to_dict, _convert_to_openai_response_format, @@ -1446,3 +1447,208 @@ def test__construct_lc_result_from_response_api_mixed_search_responses() -> None assert file_search["id"] == "filesearch_123" assert file_search["queries"] == ["python code"] assert file_search["results"][0]["filename"] == "example.py" + + +def test__construct_response_api_input_human_message_with_text_blocks_conversion() -> ( + None +): + """Test that human messages with text blocks are properly converted.""" + messages: list = [ + HumanMessage(content=[{"type": "text", "text": "What's in this image?"}]) + ] + result = _construct_response_api_input(messages) + + assert len(result) == 1 + assert result[0]["role"] == "user" + assert isinstance(result[0]["content"], list) + assert len(result[0]["content"]) == 1 + assert result[0]["content"][0]["type"] == "input_text" + assert result[0]["content"][0]["text"] == "What's in this image?" + + +def test__construct_response_api_input_human_message_with_image_url_conversion() -> ( + None +): + """Test that human messages with image_url blocks are properly converted.""" + messages: list = [ + HumanMessage( + content=[ + {"type": "text", "text": "What's in this image?"}, + { + "type": "image_url", + "image_url": { + "url": "https://example.com/image.jpg", + "detail": "high", + }, + }, + ] + ) + ] + result = _construct_response_api_input(messages) + + assert len(result) == 1 + assert result[0]["role"] == "user" + assert isinstance(result[0]["content"], list) + assert len(result[0]["content"]) == 2 + + # Check text block conversion + assert result[0]["content"][0]["type"] == "input_text" + assert result[0]["content"][0]["text"] == "What's in this image?" + + # Check image block conversion + assert result[0]["content"][1]["type"] == "input_image" + assert result[0]["content"][1]["image_url"] == "https://example.com/image.jpg" + assert result[0]["content"][1]["detail"] == "high" + + +def test__construct_response_api_input_ai_message_with_tool_calls() -> None: + """Test that AI messages with tool calls are properly converted.""" + tool_calls = [ + { + "id": "call_123", + "name": "get_weather", + "args": {"location": "San Francisco"}, + "type": "tool_call", + } + ] + + # Create a mapping from tool call IDs to function call IDs + function_call_ids = {"call_123": "func_456"} + + ai_message = AIMessage( + content="", + tool_calls=tool_calls, + additional_kwargs={_FUNCTION_CALL_IDS_MAP_KEY: function_call_ids}, + ) + + result = _construct_response_api_input([ai_message]) + + assert len(result) == 1 + assert result[0]["type"] == "function_call" + assert result[0]["name"] == "get_weather" + assert result[0]["arguments"] == '{"location": "San Francisco"}' + assert result[0]["call_id"] == "call_123" + assert result[0]["id"] == "func_456" + + +def test__construct_response_api_input_ai_message_with_tool_calls_and_content() -> None: + """Test that AI messages with both tool calls and content are properly converted.""" + tool_calls = [ + { + "id": "call_123", + "name": "get_weather", + "args": {"location": "San Francisco"}, + "type": "tool_call", + } + ] + + # Create a mapping from tool call IDs to function call IDs + function_call_ids = {"call_123": "func_456"} + + ai_message = AIMessage( + content="I'll check the weather for you.", + tool_calls=tool_calls, + additional_kwargs={_FUNCTION_CALL_IDS_MAP_KEY: function_call_ids}, + ) + + result = _construct_response_api_input([ai_message]) + + assert len(result) == 2 + + # Check content + assert result[0]["role"] == "assistant" + assert result[0]["content"] == "I'll check the weather for you." + + # Check function call + assert result[1]["type"] == "function_call" + assert result[1]["name"] == "get_weather" + assert result[1]["arguments"] == '{"location": "San Francisco"}' + assert result[1]["call_id"] == "call_123" + assert result[1]["id"] == "func_456" + + +def test__construct_response_api_input_missing_function_call_ids() -> None: + """Test AI messages with tool calls but missing function call IDs raise an error.""" + tool_calls = [ + { + "id": "call_123", + "name": "get_weather", + "args": {"location": "San Francisco"}, + "type": "tool_call", + } + ] + + ai_message = AIMessage(content="", tool_calls=tool_calls) + + with pytest.raises(ValueError): + _construct_response_api_input([ai_message]) + + +def test__construct_response_api_input_tool_message_conversion() -> None: + """Test that tool messages are properly converted to function_call_output.""" + messages = [ + ToolMessage( + content='{"temperature": 72, "conditions": "sunny"}', + tool_call_id="call_123", + ) + ] + + result = _construct_response_api_input(messages) + + assert len(result) == 1 + assert result[0]["type"] == "function_call_output" + assert result[0]["output"] == '{"temperature": 72, "conditions": "sunny"}' + assert result[0]["call_id"] == "call_123" + + +def test__construct_response_api_input_multiple_message_types() -> None: + """Test conversion of a conversation with multiple message types.""" + messages = [ + SystemMessage(content="You are a helpful assistant."), + HumanMessage(content="What's the weather in San Francisco?"), + AIMessage( + content="", + tool_calls=[ + { + "type": "tool_call", + "id": "call_123", + "name": "get_weather", + "args": {"location": "San Francisco"}, + } + ], + additional_kwargs={_FUNCTION_CALL_IDS_MAP_KEY: {"call_123": "func_456"}}, + ), + ToolMessage( + content='{"temperature": 72, "conditions": "sunny"}', + tool_call_id="call_123", + ), + AIMessage(content="The weather in San Francisco is 72°F and sunny."), + ] + + result = _construct_response_api_input(messages) + + assert len(result) == 5 + + # Check system message + assert result[0]["role"] == "system" + assert result[0]["content"] == "You are a helpful assistant." + + # Check human message + assert result[1]["role"] == "user" + assert result[1]["content"] == "What's the weather in San Francisco?" + + # Check function call + assert result[2]["type"] == "function_call" + assert result[2]["name"] == "get_weather" + assert result[2]["arguments"] == '{"location": "San Francisco"}' + assert result[2]["call_id"] == "call_123" + assert result[2]["id"] == "func_456" + + # Check function call output + assert result[3]["type"] == "function_call_output" + assert result[3]["output"] == '{"temperature": 72, "conditions": "sunny"}' + assert result[3]["call_id"] == "call_123" + + # Check final AI message + assert result[4]["role"] == "assistant" + assert result[4]["content"] == "The weather in San Francisco is 72°F and sunny."