From 3c25ee70c1a7a3b55b954c6858b32298e3c64f3d Mon Sep 17 00:00:00 2001 From: Bagatur Date: Wed, 12 Mar 2025 03:20:40 -0700 Subject: [PATCH] fmt --- .../langchain_openai/chat_models/base.py | 30 ++-- .../chat_models/test_base.py | 100 ------------- .../chat_models/test_responses_api.py | 132 ++++++++++++++++++ .../tests/unit_tests/chat_models/test_base.py | 78 ++++++----- 4 files changed, 188 insertions(+), 152 deletions(-) create mode 100644 libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 6d94be7b30f..03a4d6bd4b0 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -914,9 +914,9 @@ class BaseChatOpenAI(BaseChatModel): raw_response = self.client.with_raw_response.create(**payload) response = raw_response.parse() generation_info = {"headers": dict(raw_response.headers)} - elif _use_response_api(payload): + elif _use_responses_api(payload): response = self.root_client.responses.create(**payload) - return _construct_lc_result_from_response_api(response) + return _construct_lc_result_from_responses_api(response) else: response = self.client.create(**payload) return self._create_chat_result(response, generation_info) @@ -933,8 +933,8 @@ class BaseChatOpenAI(BaseChatModel): kwargs["stop"] = stop payload = {**self._default_params, **kwargs} - if _use_response_api(payload): - payload = _construct_response_api_payload(messages, payload) + if _use_responses_api(payload): + payload = _construct_responses_api_payload(messages, payload) else: payload["messages"] = [_convert_message_to_dict(m) for m in messages] return payload @@ -1088,9 +1088,9 @@ class BaseChatOpenAI(BaseChatModel): raw_response = await self.async_client.with_raw_response.create(**payload) response = raw_response.parse() generation_info = {"headers": dict(raw_response.headers)} - elif _use_response_api(payload): + elif _use_responses_api(payload): response = await self.root_async_client.responses.create(**payload) - return _construct_lc_result_from_response_api(response) + return _construct_lc_result_from_responses_api(response) else: response = await self.async_client.create(**payload) return await run_in_executor( @@ -2189,7 +2189,7 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override] self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any ) -> Iterator[ChatGenerationChunk]: """Set default stream_options.""" - if _use_response_api(kwargs): + if _use_responses_api(kwargs): return super()._stream_responses(*args, **kwargs) else: stream_usage = self._should_stream_usage(stream_usage, **kwargs) @@ -2207,7 +2207,7 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override] self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any ) -> AsyncIterator[ChatGenerationChunk]: """Set default stream_options.""" - if _use_response_api(kwargs): + if _use_responses_api(kwargs): async for chunk in super()._astream_responses(*args, **kwargs): yield chunk else: @@ -2760,16 +2760,18 @@ def _is_builtin_tool(tool: dict) -> bool: return "type" in tool and tool["type"] != "function" -def _use_response_api(payload: dict) -> bool: - return "tools" in payload and any( +def _use_responses_api(payload: dict) -> bool: + uses_builtin_tools = "tools" in payload and any( _is_builtin_tool(tool) for tool in payload["tools"] ) + responses_only_args = {"previous_response_id", "text", 
"truncation", "include"} + return bool(uses_builtin_tools or responses_only_args.intersection(payload)) -def _construct_response_api_payload( +def _construct_responses_api_payload( messages: Sequence[BaseMessage], payload: dict ) -> dict: - payload["input"] = _construct_response_api_input(messages) + payload["input"] = _construct_responses_api_input(messages) if tools := payload.pop("tools", None): new_tools: list = [] for tool in tools: @@ -2803,7 +2805,7 @@ def _construct_response_api_payload( return payload -def _construct_response_api_input(messages: Sequence[BaseMessage]) -> list: +def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list: input_ = [] for lc_msg in messages: msg = _convert_message_to_dict(lc_msg) @@ -2899,7 +2901,7 @@ def _construct_response_api_input(messages: Sequence[BaseMessage]) -> list: return input_ -def _construct_lc_result_from_response_api(response: Response) -> ChatResult: +def _construct_lc_result_from_responses_api(response: Response) -> ChatResult: """Construct ChatResponse from OpenAI Response API response.""" if response.error: raise ValueError(response.error) diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py index 41ca3db7189..09cae79520b 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py @@ -2,7 +2,6 @@ import base64 import json -import os from pathlib import Path from textwrap import dedent from typing import Any, AsyncIterator, List, Literal, Optional, cast @@ -1229,102 +1228,3 @@ def test_structured_output_and_tools() -> None: assert len(full.tool_calls) == 1 tool_call = full.tool_calls[0] assert tool_call["name"] == "GenerateUsername" - - -def _check_response(response: Optional[BaseMessage]) -> None: - assert isinstance(response, AIMessage) - assert isinstance(response.content, list) - for block in response.content: - assert isinstance(block, dict) - if block["type"] == "text": - assert isinstance(block["text"], str) - for annotation in block["annotations"]: - if annotation["type"] == "file_citation": - assert all( - key in annotation - for key in ["file_id", "filename", "index", "type"] - ) - elif annotation["type"] == "web_search": - assert all( - key in annotation - for key in ["end_index", "start_index", "title", "type", "url"] - ) - - text_content = response.text() - assert isinstance(text_content, str) - assert text_content - assert response.usage_metadata - assert response.usage_metadata["input_tokens"] > 0 - assert response.usage_metadata["output_tokens"] > 0 - assert response.usage_metadata["total_tokens"] > 0 - assert response.response_metadata["model_name"] - for tool_output in response.additional_kwargs["tool_outputs"]: - assert tool_output["id"] - assert tool_output["status"] - assert tool_output["type"] - - -def test_web_search() -> None: - llm = ChatOpenAI(model="gpt-4o-mini") - response = llm.invoke( - "What was a positive news story from today?", - tools=[{"type": "web_search_preview"}], - ) - _check_response(response) - assert response.response_metadata["status"] - - # Test streaming - full: Optional[BaseMessageChunk] = None - for chunk in llm.stream( - "What was a positive news story from today?", - tools=[{"type": "web_search_preview"}], - ): - assert isinstance(chunk, AIMessageChunk) - full = chunk if full is None else full + chunk - _check_response(full) - - llm.invoke( - "what about a negative 
one", - tools=[{"type": "web_search_preview"}], - response_id=response.response_metadata["id"], - ) - _check_response(response) - - -async def test_web_search_async() -> None: - llm = ChatOpenAI(model="gpt-4o-mini") - response = await llm.ainvoke( - "What was a positive news story from today?", - tools=[{"type": "web_search_preview"}], - ) - _check_response(response) - assert response.response_metadata["status"] - - # Test streaming - full: Optional[BaseMessageChunk] = None - async for chunk in llm.astream( - "What was a positive news story from today?", - tools=[{"type": "web_search_preview"}], - ): - assert isinstance(chunk, AIMessageChunk) - full = chunk if full is None else full + chunk - assert isinstance(full, AIMessageChunk) - _check_response(full) - - -def test_file_search() -> None: - pytest.skip() # TODO: set up infra - llm = ChatOpenAI(model="gpt-4o-mini") - tool = { - "type": "file_search", - "vector_store_ids": [os.environ["OPENAI_VECTOR_STORE_ID"]], - } - response = llm.invoke("What is deep research by OpenAI?", tools=[tool]) - _check_response(response) - - full: Optional[BaseMessageChunk] = None - for chunk in llm.stream("What is deep research by OpenAI?", tools=[tool]): - assert isinstance(chunk, AIMessageChunk) - full = chunk if full is None else full + chunk - assert isinstance(full, AIMessageChunk) - _check_response(full) diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py b/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py new file mode 100644 index 00000000000..b7e7550f231 --- /dev/null +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py @@ -0,0 +1,132 @@ +"""Test Responses API usage.""" + +import os +from typing import Optional + +import pytest +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + BaseMessageChunk, +) + +from langchain_openai import ChatOpenAI + + +def _check_response(response: Optional[BaseMessage]) -> None: + assert isinstance(response, AIMessage) + assert isinstance(response.content, list) + for block in response.content: + assert isinstance(block, dict) + if block["type"] == "text": + assert isinstance(block["text"], str) + for annotation in block["annotations"]: + if annotation["type"] == "file_citation": + assert all( + key in annotation + for key in ["file_id", "filename", "index", "type"] + ) + elif annotation["type"] == "web_search": + assert all( + key in annotation + for key in ["end_index", "start_index", "title", "type", "url"] + ) + + text_content = response.text() + assert isinstance(text_content, str) + assert text_content + assert response.usage_metadata + assert response.usage_metadata["input_tokens"] > 0 + assert response.usage_metadata["output_tokens"] > 0 + assert response.usage_metadata["total_tokens"] > 0 + assert response.response_metadata["model_name"] + for tool_output in response.additional_kwargs["tool_outputs"]: + assert tool_output["id"] + assert tool_output["status"] + assert tool_output["type"] + + +def test_web_search() -> None: + llm = ChatOpenAI(model="gpt-4o-mini") + first_response = llm.invoke( + "What was a positive news story from today?", + tools=[{"type": "web_search_preview"}], + ) + _check_response(first_response) + + # Test streaming + full: Optional[BaseMessageChunk] = None + for chunk in llm.stream( + "What was a positive news story from today?", + tools=[{"type": "web_search_preview"}], + ): + assert isinstance(chunk, AIMessageChunk) + full = chunk if full is None else 
full + chunk + _check_response(full) + + # Use OpenAI's stateful API + response = llm.invoke( + "what about a negative one", + tools=[{"type": "web_search_preview"}], + previous_response_id=first_response.response_metadata["id"], + ) + _check_response(response) + + # Manually pass in chat history + response = llm.invoke( + [ + first_response, + { + "role": "user", + "content": [{"type": "text", "text": "what about a negative one"}], + }, + ], + tools=[{"type": "web_search_preview"}], + ) + _check_response(response) + + # Bind tool + response = llm.bind_tools([{"type": "web_search_preview"}]).invoke( + "What was a positive news story from today?" + ) + _check_response(response) + + +async def test_web_search_async() -> None: + llm = ChatOpenAI(model="gpt-4o-mini") + response = await llm.ainvoke( + "What was a positive news story from today?", + tools=[{"type": "web_search_preview"}], + ) + _check_response(response) + assert response.response_metadata["status"] + + # Test streaming + full: Optional[BaseMessageChunk] = None + async for chunk in llm.astream( + "What was a positive news story from today?", + tools=[{"type": "web_search_preview"}], + ): + assert isinstance(chunk, AIMessageChunk) + full = chunk if full is None else full + chunk + assert isinstance(full, AIMessageChunk) + _check_response(full) + + +def test_file_search() -> None: + pytest.skip() # TODO: set up infra + llm = ChatOpenAI(model="gpt-4o-mini") + tool = { + "type": "file_search", + "vector_store_ids": [os.environ["OPENAI_VECTOR_STORE_ID"]], + } + response = llm.invoke("What is deep research by OpenAI?", tools=[tool]) + _check_response(response) + + full: Optional[BaseMessageChunk] = None + for chunk in llm.stream("What is deep research by OpenAI?", tools=[tool]): + assert isinstance(chunk, AIMessageChunk) + full = chunk if full is None else full + chunk + assert isinstance(full, AIMessageChunk) + _check_response(full) diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py index fd9c96ea703..e5e89990b78 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py @@ -41,8 +41,8 @@ from typing_extensions import TypedDict from langchain_openai import ChatOpenAI from langchain_openai.chat_models.base import ( _FUNCTION_CALL_IDS_MAP_KEY, - _construct_lc_result_from_response_api, - _construct_response_api_input, + _construct_lc_result_from_responses_api, + _construct_responses_api_input, _convert_dict_to_message, _convert_message_to_dict, _convert_to_openai_response_format, @@ -955,7 +955,7 @@ def test_structured_outputs_parser() -> None: assert result == parsed_response -def test__construct_lc_result_from_response_api_error_handling() -> None: +def test__construct_lc_result_from_responses_api_error_handling() -> None: """Test that errors in the response are properly raised.""" response = Response( id="resp_123", @@ -970,12 +970,12 @@ def test__construct_lc_result_from_response_api_error_handling() -> None: ) with pytest.raises(ValueError) as excinfo: - _construct_lc_result_from_response_api(response) + _construct_lc_result_from_responses_api(response) assert "Test error" in str(excinfo.value) -def test__construct_lc_result_from_response_api_basic_text_response() -> None: +def test__construct_lc_result_from_responses_api_basic_text_response() -> None: """Test a basic text response with no tools or special features.""" response = Response( id="resp_123", @@ 
-1006,7 +1006,7 @@ def test__construct_lc_result_from_response_api_basic_text_response() -> None: ), ) - result = _construct_lc_result_from_response_api(response) + result = _construct_lc_result_from_responses_api(response) assert isinstance(result, ChatResult) assert len(result.generations) == 1 @@ -1024,7 +1024,7 @@ def test__construct_lc_result_from_response_api_basic_text_response() -> None: assert result.generations[0].message.response_metadata["model_name"] == "gpt-4o" -def test__construct_lc_result_from_response_api_multiple_text_blocks() -> None: +def test__construct_lc_result_from_responses_api_multiple_text_blocks() -> None: """Test a response with multiple text blocks.""" response = Response( id="resp_123", @@ -1052,14 +1052,14 @@ def test__construct_lc_result_from_response_api_multiple_text_blocks() -> None: ], ) - result = _construct_lc_result_from_response_api(response) + result = _construct_lc_result_from_responses_api(response) assert len(result.generations[0].message.content) == 2 assert result.generations[0].message.content[0]["text"] == "First part" # type: ignore assert result.generations[0].message.content[1]["text"] == "Second part" # type: ignore -def test__construct_lc_result_from_response_api_refusal_response() -> None: +def test__construct_lc_result_from_responses_api_refusal_response() -> None: """Test a response with a refusal.""" response = Response( id="resp_123", @@ -1084,7 +1084,7 @@ def test__construct_lc_result_from_response_api_refusal_response() -> None: ], ) - result = _construct_lc_result_from_response_api(response) + result = _construct_lc_result_from_responses_api(response) assert result.generations[0].message.content == [] assert ( @@ -1093,7 +1093,7 @@ def test__construct_lc_result_from_response_api_refusal_response() -> None: ) -def test__construct_lc_result_from_response_api_function_call_valid_json() -> None: +def test__construct_lc_result_from_responses_api_function_call_valid_json() -> None: """Test a response with a valid function call.""" response = Response( id="resp_123", @@ -1114,7 +1114,7 @@ def test__construct_lc_result_from_response_api_function_call_valid_json() -> No ], ) - result = _construct_lc_result_from_response_api(response) + result = _construct_lc_result_from_responses_api(response) msg: AIMessage = cast(AIMessage, result.generations[0].message) assert len(msg.tool_calls) == 1 @@ -1131,7 +1131,7 @@ def test__construct_lc_result_from_response_api_function_call_valid_json() -> No ) -def test__construct_lc_result_from_response_api_function_call_invalid_json() -> None: +def test__construct_lc_result_from_responses_api_function_call_invalid_json() -> None: """Test a response with an invalid JSON function call.""" response = Response( id="resp_123", @@ -1153,7 +1153,7 @@ def test__construct_lc_result_from_response_api_function_call_invalid_json() -> ], ) - result = _construct_lc_result_from_response_api(response) + result = _construct_lc_result_from_responses_api(response) msg: AIMessage = cast(AIMessage, result.generations[0].message) assert len(msg.invalid_tool_calls) == 1 @@ -1168,7 +1168,7 @@ def test__construct_lc_result_from_response_api_function_call_invalid_json() -> assert _FUNCTION_CALL_IDS_MAP_KEY in result.generations[0].message.additional_kwargs -def test__construct_lc_result_from_response_api_complex_response() -> None: +def test__construct_lc_result_from_responses_api_complex_response() -> None: """Test a complex response with multiple output types.""" response = Response( id="resp_123", @@ -1206,7 +1206,7 @@ def 
test__construct_lc_result_from_response_api_complex_response() -> None: user="user_123", ) - result = _construct_lc_result_from_response_api(response) + result = _construct_lc_result_from_responses_api(response) # Check message content assert result.generations[0].message.content == [ @@ -1235,7 +1235,7 @@ def test__construct_lc_result_from_response_api_complex_response() -> None: assert result.generations[0].message.response_metadata["user"] == "user_123" -def test__construct_lc_result_from_response_api_no_usage_metadata() -> None: +def test__construct_lc_result_from_responses_api_no_usage_metadata() -> None: """Test a response without usage metadata.""" response = Response( id="resp_123", @@ -1261,12 +1261,12 @@ def test__construct_lc_result_from_response_api_no_usage_metadata() -> None: # No usage field ) - result = _construct_lc_result_from_response_api(response) + result = _construct_lc_result_from_responses_api(response) assert cast(AIMessage, result.generations[0].message).usage_metadata is None -def test__construct_lc_result_from_response_api_web_search_response() -> None: +def test__construct_lc_result_from_responses_api_web_search_response() -> None: """Test a response with web search output.""" from openai.types.responses.response_function_web_search import ( ResponseFunctionWebSearch, @@ -1287,7 +1287,7 @@ def test__construct_lc_result_from_response_api_web_search_response() -> None: ], ) - result = _construct_lc_result_from_response_api(response) + result = _construct_lc_result_from_responses_api(response) assert "tool_outputs" in result.generations[0].message.additional_kwargs assert len(result.generations[0].message.additional_kwargs["tool_outputs"]) == 1 @@ -1305,7 +1305,7 @@ def test__construct_lc_result_from_response_api_web_search_response() -> None: ) -def test__construct_lc_result_from_response_api_file_search_response() -> None: +def test__construct_lc_result_from_responses_api_file_search_response() -> None: """Test a response with file search output.""" response = Response( id="resp_123", @@ -1334,7 +1334,7 @@ def test__construct_lc_result_from_response_api_file_search_response() -> None: ], ) - result = _construct_lc_result_from_response_api(response) + result = _construct_lc_result_from_responses_api(response) assert "tool_outputs" in result.generations[0].message.additional_kwargs assert len(result.generations[0].message.additional_kwargs["tool_outputs"]) == 1 @@ -1375,7 +1375,7 @@ def test__construct_lc_result_from_response_api_file_search_response() -> None: ) -def test__construct_lc_result_from_response_api_mixed_search_responses() -> None: +def test__construct_lc_result_from_responses_api_mixed_search_responses() -> None: """Test a response with both web search and file search outputs.""" response = Response( @@ -1418,7 +1418,7 @@ def test__construct_lc_result_from_response_api_mixed_search_responses() -> None ], ) - result = _construct_lc_result_from_response_api(response) + result = _construct_lc_result_from_responses_api(response) # Check message content assert result.generations[0].message.content == [ @@ -1449,14 +1449,14 @@ def test__construct_lc_result_from_response_api_mixed_search_responses() -> None assert file_search["results"][0]["filename"] == "example.py" -def test__construct_response_api_input_human_message_with_text_blocks_conversion() -> ( +def test__construct_responses_api_input_human_message_with_text_blocks_conversion() -> ( None ): """Test that human messages with text blocks are properly converted.""" messages: list = [ 
HumanMessage(content=[{"type": "text", "text": "What's in this image?"}]) ] - result = _construct_response_api_input(messages) + result = _construct_responses_api_input(messages) assert len(result) == 1 assert result[0]["role"] == "user" @@ -1466,7 +1466,7 @@ def test__construct_response_api_input_human_message_with_text_blocks_conversion assert result[0]["content"][0]["text"] == "What's in this image?" -def test__construct_response_api_input_human_message_with_image_url_conversion() -> ( +def test__construct_responses_api_input_human_message_with_image_url_conversion() -> ( None ): """Test that human messages with image_url blocks are properly converted.""" @@ -1484,7 +1484,7 @@ def test__construct_response_api_input_human_message_with_image_url_conversion() ] ) ] - result = _construct_response_api_input(messages) + result = _construct_responses_api_input(messages) assert len(result) == 1 assert result[0]["role"] == "user" @@ -1501,7 +1501,7 @@ def test__construct_response_api_input_human_message_with_image_url_conversion() assert result[0]["content"][1]["detail"] == "high" -def test__construct_response_api_input_ai_message_with_tool_calls() -> None: +def test__construct_responses_api_input_ai_message_with_tool_calls() -> None: """Test that AI messages with tool calls are properly converted.""" tool_calls = [ { @@ -1521,7 +1521,7 @@ def test__construct_response_api_input_ai_message_with_tool_calls() -> None: additional_kwargs={_FUNCTION_CALL_IDS_MAP_KEY: function_call_ids}, ) - result = _construct_response_api_input([ai_message]) + result = _construct_responses_api_input([ai_message]) assert len(result) == 1 assert result[0]["type"] == "function_call" @@ -1531,7 +1531,9 @@ def test__construct_response_api_input_ai_message_with_tool_calls() -> None: assert result[0]["id"] == "func_456" -def test__construct_response_api_input_ai_message_with_tool_calls_and_content() -> None: +def test__construct_responses_api_input_ai_message_with_tool_calls_and_content() -> ( + None +): """Test that AI messages with both tool calls and content are properly converted.""" tool_calls = [ { @@ -1551,7 +1553,7 @@ def test__construct_response_api_input_ai_message_with_tool_calls_and_content() additional_kwargs={_FUNCTION_CALL_IDS_MAP_KEY: function_call_ids}, ) - result = _construct_response_api_input([ai_message]) + result = _construct_responses_api_input([ai_message]) assert len(result) == 2 @@ -1567,7 +1569,7 @@ def test__construct_response_api_input_ai_message_with_tool_calls_and_content() assert result[1]["id"] == "func_456" -def test__construct_response_api_input_missing_function_call_ids() -> None: +def test__construct_responses_api_input_missing_function_call_ids() -> None: """Test AI messages with tool calls but missing function call IDs raise an error.""" tool_calls = [ { @@ -1581,10 +1583,10 @@ def test__construct_response_api_input_missing_function_call_ids() -> None: ai_message = AIMessage(content="", tool_calls=tool_calls) with pytest.raises(ValueError): - _construct_response_api_input([ai_message]) + _construct_responses_api_input([ai_message]) -def test__construct_response_api_input_tool_message_conversion() -> None: +def test__construct_responses_api_input_tool_message_conversion() -> None: """Test that tool messages are properly converted to function_call_output.""" messages = [ ToolMessage( @@ -1593,7 +1595,7 @@ def test__construct_response_api_input_tool_message_conversion() -> None: ) ] - result = _construct_response_api_input(messages) + result = _construct_responses_api_input(messages) 
assert len(result) == 1 assert result[0]["type"] == "function_call_output" @@ -1601,7 +1603,7 @@ def test__construct_response_api_input_tool_message_conversion() -> None: assert result[0]["call_id"] == "call_123" -def test__construct_response_api_input_multiple_message_types() -> None: +def test__construct_responses_api_input_multiple_message_types() -> None: """Test conversion of a conversation with multiple message types.""" messages = [ SystemMessage(content="You are a helpful assistant."), @@ -1637,7 +1639,7 @@ def test__construct_response_api_input_multiple_message_types() -> None: ] messages_copy = [m.copy(deep=True) for m in messages] - result = _construct_response_api_input(messages) + result = _construct_responses_api_input(messages) assert len(result) == len(messages)
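
Reviewer note: a minimal usage sketch of the routing this patch adds, drawn from the integration tests above (assumes a valid OPENAI_API_KEY; the model name and built-in tool are the ones the tests use, not requirements of the API). A built-in (non-"function") tool in "tools", or any Responses-only argument ("previous_response_id", "text", "truncation", "include"), now makes _use_responses_api(payload) return True, so the request is sent via client.responses.create rather than chat completions:

    from langchain_openai import ChatOpenAI

    llm = ChatOpenAI(model="gpt-4o-mini")

    # A built-in tool (type != "function") selects the Responses API path.
    first = llm.invoke(
        "What was a positive news story from today?",
        tools=[{"type": "web_search_preview"}],
    )

    # A Responses-only kwarg alone also selects that path; here it reuses
    # OpenAI's server-side conversation state from the first call.
    follow_up = llm.invoke(
        "what about a negative one",
        tools=[{"type": "web_search_preview"}],
        previous_response_id=first.response_metadata["id"],
    )
    print(follow_up.text())

As the new test_web_search also exercises, the stateless alternative is to pass the prior AIMessage back in the input list instead of previous_response_id; both routes go through _construct_responses_api_payload.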