From f27f64af1916b471548e72f97148787de1e51023 Mon Sep 17 00:00:00 2001
From: "open-swe[bot]"
Date: Mon, 11 Aug 2025 20:19:41 +0000
Subject: [PATCH] Apply patch [skip ci]

---
 .../unit_tests/chat_models/test_batch.py | 598 ++++++++++++++++++
 1 file changed, 598 insertions(+)
 create mode 100644 libs/partners/openai/tests/unit_tests/chat_models/test_batch.py

diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_batch.py b/libs/partners/openai/tests/unit_tests/chat_models/test_batch.py
new file mode 100644
index 00000000000..adcd26eecea
--- /dev/null
+++ b/libs/partners/openai/tests/unit_tests/chat_models/test_batch.py
@@ -0,0 +1,598 @@
+"""Test OpenAI Batch API functionality."""
+
+import json
+from unittest.mock import MagicMock, patch
+
+import pytest
+from langchain_core.messages import AIMessage, HumanMessage
+from langchain_core.outputs import ChatGeneration, ChatResult
+
+from langchain_openai import ChatOpenAI
+from langchain_openai.chat_models.batch import (
+    BatchError,
+    OpenAIBatchClient,
+    OpenAIBatchProcessor,
+)
+
+
+class TestOpenAIBatchClient:
+    """Test the OpenAIBatchClient class."""
+
+    def setup_method(self):
+        """Set up test fixtures."""
+        self.mock_client = MagicMock()
+        self.batch_client = OpenAIBatchClient(
+            client=self.mock_client,
+            poll_interval=0.1,  # Fast polling for tests
+            timeout=5.0,
+        )
+
+    def test_create_batch_success(self):
+        """Test successful batch creation."""
+        # Mock batch creation response
+        mock_batch = MagicMock()
+        mock_batch.id = "batch_123"
+        mock_batch.status = "validating"
+        self.mock_client.batches.create.return_value = mock_batch
+
+        batch_requests = [
+            {
+                "custom_id": "request-1",
+                "method": "POST",
+                "url": "/v1/chat/completions",
+                "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello"}]},
+            }
+        ]
+
+        batch_id = self.batch_client.create_batch(
+            batch_requests=batch_requests,
+            description="Test batch",
+            metadata={"test": "true"},
+        )
+
+        assert batch_id == "batch_123"
+        self.mock_client.batches.create.assert_called_once()
+
+    def test_create_batch_failure(self):
+        """Test batch creation failure."""
+        self.mock_client.batches.create.side_effect = Exception("API Error")
+
+        batch_requests = [
+            {
+                "custom_id": "request-1",
+                "method": "POST",
+                "url": "/v1/chat/completions",
+                "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello"}]},
+            }
+        ]
+
+        with pytest.raises(BatchError, match="Failed to create batch"):
+            self.batch_client.create_batch(batch_requests=batch_requests)
+
+    def test_poll_batch_status_completed(self):
+        """Test polling until batch completion."""
+        # Mock batch status progression
+        mock_batch_validating = MagicMock()
+        mock_batch_validating.status = "validating"
+
+        mock_batch_in_progress = MagicMock()
+        mock_batch_in_progress.status = "in_progress"
+
+        mock_batch_completed = MagicMock()
+        mock_batch_completed.status = "completed"
+        mock_batch_completed.output_file_id = "file_123"
+
+        self.mock_client.batches.retrieve.side_effect = [
+            mock_batch_validating,
+            mock_batch_in_progress,
+            mock_batch_completed,
+        ]
+
+        result = self.batch_client.poll_batch_status("batch_123")
+
+        assert result.status == "completed"
+        assert result.output_file_id == "file_123"
+        assert self.mock_client.batches.retrieve.call_count == 3
+
+    def test_poll_batch_status_failed(self):
+        """Test polling when batch fails."""
+        mock_batch_failed = MagicMock()
+        mock_batch_failed.status = "failed"
+        mock_batch_failed.errors = [{"message": "Batch processing failed"}]
+
+        self.mock_client.batches.retrieve.return_value = mock_batch_failed
+
+        with pytest.raises(BatchError, match="Batch failed"):
+            self.batch_client.poll_batch_status("batch_123")
+
+    def test_poll_batch_status_timeout(self):
+        """Test polling timeout."""
+        mock_batch_in_progress = MagicMock()
+        mock_batch_in_progress.status = "in_progress"
+
+        self.mock_client.batches.retrieve.return_value = mock_batch_in_progress
+
+        # Set very short timeout
+        self.batch_client.timeout = 0.2
+
+        with pytest.raises(BatchError, match="Batch polling timed out"):
+            self.batch_client.poll_batch_status("batch_123")
+
+    def test_retrieve_batch_results_success(self):
+        """Test successful batch result retrieval."""
+        # Mock file content
+        mock_results = [
+            {
+                "id": "batch_req_123",
+                "custom_id": "request-1",
+                "response": {
+                    "status_code": 200,
+                    "body": {
+                        "choices": [
+                            {
+                                "message": {
+                                    "role": "assistant",
+                                    "content": "Hello! How can I help you?"
+                                }
+                            }
+                        ],
+                        "usage": {"prompt_tokens": 10, "completion_tokens": 8, "total_tokens": 18}
+                    }
+                }
+            }
+        ]
+
+        mock_file_content = "\n".join(json.dumps(result) for result in mock_results)
+        self.mock_client.files.content.return_value.content = mock_file_content.encode()
+
+        results = self.batch_client.retrieve_batch_results("file_123")
+
+        assert len(results) == 1
+        assert results[0]["custom_id"] == "request-1"
+        assert results[0]["response"]["status_code"] == 200
+
+    def test_retrieve_batch_results_failure(self):
+        """Test batch result retrieval failure."""
+        self.mock_client.files.content.side_effect = Exception("File not found")
+
+        with pytest.raises(BatchError, match="Failed to retrieve batch results"):
+            self.batch_client.retrieve_batch_results("file_123")
+
+
+class TestOpenAIBatchProcessor:
+    """Test the OpenAIBatchProcessor class."""
+
+    def setup_method(self):
+        """Set up test fixtures."""
+        self.mock_client = MagicMock()
+        self.processor = OpenAIBatchProcessor(
+            client=self.mock_client,
+            model="gpt-3.5-turbo",
+            poll_interval=0.1,
+            timeout=5.0,
+        )
+
+    def test_create_batch_success(self):
+        """Test successful batch creation with message conversion."""
+        # Mock batch client
+        with patch.object(self.processor, 'batch_client') as mock_batch_client:
+            mock_batch_client.create_batch.return_value = "batch_123"
+
+            messages_list = [
+                [HumanMessage(content="What is 2+2?")],
+                [HumanMessage(content="What is the capital of France?")],
+            ]
+
+            batch_id = self.processor.create_batch(
+                messages_list=messages_list,
+                description="Test batch",
+                metadata={"test": "true"},
+                temperature=0.7,
+            )
+
+            assert batch_id == "batch_123"
+            mock_batch_client.create_batch.assert_called_once()
+
+            # Verify batch requests were created correctly
+            call_args = mock_batch_client.create_batch.call_args
+            batch_requests = call_args[1]["batch_requests"]
+
+            assert len(batch_requests) == 2
+            assert batch_requests[0]["custom_id"] == "request-0"
+            assert batch_requests[0]["body"]["model"] == "gpt-3.5-turbo"
+            assert batch_requests[0]["body"]["temperature"] == 0.7
+            assert batch_requests[0]["body"]["messages"][0]["role"] == "user"
+            assert batch_requests[0]["body"]["messages"][0]["content"] == "What is 2+2?"
+ + def test_poll_batch_status_success(self): + """Test successful batch status polling.""" + with patch.object(self.processor, 'batch_client') as mock_batch_client: + mock_batch = MagicMock() + mock_batch.status = "completed" + mock_batch_client.poll_batch_status.return_value = mock_batch + + result = self.processor.poll_batch_status("batch_123") + + assert result.status == "completed" + mock_batch_client.poll_batch_status.assert_called_once_with( + "batch_123", poll_interval=None, timeout=None + ) + + def test_retrieve_batch_results_success(self): + """Test successful batch result retrieval and conversion.""" + # Mock batch status and results + mock_batch = MagicMock() + mock_batch.status = "completed" + mock_batch.output_file_id = "file_123" + + mock_results = [ + { + "id": "batch_req_123", + "custom_id": "request-0", + "response": { + "status_code": 200, + "body": { + "choices": [ + { + "message": { + "role": "assistant", + "content": "2+2 equals 4." + } + } + ], + "usage": {"prompt_tokens": 10, "completion_tokens": 8, "total_tokens": 18} + } + } + }, + { + "id": "batch_req_124", + "custom_id": "request-1", + "response": { + "status_code": 200, + "body": { + "choices": [ + { + "message": { + "role": "assistant", + "content": "The capital of France is Paris." + } + } + ], + "usage": {"prompt_tokens": 12, "completion_tokens": 10, "total_tokens": 22} + } + } + } + ] + + with patch.object(self.processor, 'batch_client') as mock_batch_client: + mock_batch_client.poll_batch_status.return_value = mock_batch + mock_batch_client.retrieve_batch_results.return_value = mock_results + + chat_results = self.processor.retrieve_batch_results("batch_123") + + assert len(chat_results) == 2 + + # Check first result + assert isinstance(chat_results[0], ChatResult) + assert len(chat_results[0].generations) == 1 + assert isinstance(chat_results[0].generations[0].message, AIMessage) + assert chat_results[0].generations[0].message.content == "2+2 equals 4." + + # Check second result + assert isinstance(chat_results[1], ChatResult) + assert len(chat_results[1].generations) == 1 + assert isinstance(chat_results[1].generations[0].message, AIMessage) + assert chat_results[1].generations[0].message.content == "The capital of France is Paris." 
+ + def test_retrieve_batch_results_with_errors(self): + """Test batch result retrieval with some failed requests.""" + mock_batch = MagicMock() + mock_batch.status = "completed" + mock_batch.output_file_id = "file_123" + + mock_results = [ + { + "id": "batch_req_123", + "custom_id": "request-0", + "response": { + "status_code": 200, + "body": { + "choices": [ + { + "message": { + "role": "assistant", + "content": "Success response" + } + } + ], + "usage": {"prompt_tokens": 10, "completion_tokens": 8, "total_tokens": 18} + } + } + }, + { + "id": "batch_req_124", + "custom_id": "request-1", + "response": { + "status_code": 400, + "body": { + "error": { + "message": "Invalid request", + "type": "invalid_request_error" + } + } + } + } + ] + + with patch.object(self.processor, 'batch_client') as mock_batch_client: + mock_batch_client.poll_batch_status.return_value = mock_batch + mock_batch_client.retrieve_batch_results.return_value = mock_results + + with pytest.raises(BatchError, match="Batch request request-1 failed"): + self.processor.retrieve_batch_results("batch_123") + + +class TestBaseChatOpenAIBatchMethods: + """Test the batch methods added to BaseChatOpenAI.""" + + def setup_method(self): + """Set up test fixtures.""" + self.llm = ChatOpenAI(model="gpt-3.5-turbo", api_key="test-key") + + @patch('langchain_openai.chat_models.batch.OpenAIBatchProcessor') + def test_batch_create_success(self, mock_processor_class): + """Test successful batch creation.""" + mock_processor = MagicMock() + mock_processor.create_batch.return_value = "batch_123" + mock_processor_class.return_value = mock_processor + + messages_list = [ + [HumanMessage(content="What is 2+2?")], + [HumanMessage(content="What is the capital of France?")], + ] + + batch_id = self.llm.batch_create( + messages_list=messages_list, + description="Test batch", + metadata={"test": "true"}, + temperature=0.7, + ) + + assert batch_id == "batch_123" + mock_processor_class.assert_called_once() + mock_processor.create_batch.assert_called_once_with( + messages_list=messages_list, + description="Test batch", + metadata={"test": "true"}, + temperature=0.7, + ) + + @patch('langchain_openai.chat_models.batch.OpenAIBatchProcessor') + def test_batch_retrieve_success(self, mock_processor_class): + """Test successful batch result retrieval.""" + mock_processor = MagicMock() + mock_chat_results = [ + ChatResult(generations=[ChatGeneration(message=AIMessage(content="2+2 equals 4."))]), + ChatResult(generations=[ChatGeneration(message=AIMessage(content="The capital of France is Paris."))]), + ] + mock_processor.retrieve_batch_results.return_value = mock_chat_results + mock_processor_class.return_value = mock_processor + + results = self.llm.batch_retrieve("batch_123", poll_interval=1.0, timeout=60.0) + + assert len(results) == 2 + assert results[0].generations[0].message.content == "2+2 equals 4." + assert results[1].generations[0].message.content == "The capital of France is Paris." 
+
+        mock_processor.poll_batch_status.assert_called_once_with(
+            batch_id="batch_123",
+            poll_interval=1.0,
+            timeout=60.0,
+        )
+        mock_processor.retrieve_batch_results.assert_called_once_with("batch_123")
+
+    @patch('langchain_openai.chat_models.batch.OpenAIBatchProcessor')
+    def test_batch_method_with_batch_api_true(self, mock_processor_class):
+        """Test batch method with use_batch_api=True."""
+        mock_processor = MagicMock()
+        mock_chat_results = [
+            ChatResult(generations=[ChatGeneration(message=AIMessage(content="Response 1"))]),
+            ChatResult(generations=[ChatGeneration(message=AIMessage(content="Response 2"))]),
+        ]
+        mock_processor.create_batch.return_value = "batch_123"
+        mock_processor.retrieve_batch_results.return_value = mock_chat_results
+        mock_processor_class.return_value = mock_processor
+
+        inputs = [
+            [HumanMessage(content="Question 1")],
+            [HumanMessage(content="Question 2")],
+        ]
+
+        results = self.llm.batch(inputs, use_batch_api=True, temperature=0.5)
+
+        assert len(results) == 2
+        assert isinstance(results[0], AIMessage)
+        assert results[0].content == "Response 1"
+        assert isinstance(results[1], AIMessage)
+        assert results[1].content == "Response 2"
+
+        mock_processor.create_batch.assert_called_once()
+        mock_processor.retrieve_batch_results.assert_called_once()
+
+    def test_batch_method_with_batch_api_false(self):
+        """Test batch method with use_batch_api=False (default behavior)."""
+        inputs = [
+            [HumanMessage(content="Question 1")],
+            [HumanMessage(content="Question 2")],
+        ]
+
+        # Mock the parent class batch method
+        with patch.object(ChatOpenAI.__bases__[0], 'batch') as mock_super_batch:
+            mock_super_batch.return_value = [
+                AIMessage(content="Response 1"),
+                AIMessage(content="Response 2"),
+            ]
+
+            results = self.llm.batch(inputs, use_batch_api=False)
+
+            assert len(results) == 2
+            mock_super_batch.assert_called_once_with(
+                inputs=inputs,
+                config=None,
+                return_exceptions=False,
+            )
+
+    def test_convert_input_to_messages_list(self):
+        """Test _convert_input_to_messages helper method."""
+        # Test list of messages
+        messages = [HumanMessage(content="Hello")]
+        result = self.llm._convert_input_to_messages(messages)
+        assert result == messages
+
+        # Test single message
+        message = HumanMessage(content="Hello")
+        result = self.llm._convert_input_to_messages(message)
+        assert result == [message]
+
+        # Test string input
+        result = self.llm._convert_input_to_messages("Hello")
+        assert len(result) == 1
+        assert isinstance(result[0], HumanMessage)
+        assert result[0].content == "Hello"
+
+    @patch('langchain_openai.chat_models.batch.OpenAIBatchProcessor')
+    def test_batch_create_with_error_handling(self, mock_processor_class):
+        """Test batch creation with error handling."""
+        mock_processor = MagicMock()
+        mock_processor.create_batch.side_effect = BatchError("Batch creation failed")
+        mock_processor_class.return_value = mock_processor
+
+        messages_list = [[HumanMessage(content="Test")]]
+
+        with pytest.raises(BatchError, match="Batch creation failed"):
+            self.llm.batch_create(messages_list)
+
+    @patch('langchain_openai.chat_models.batch.OpenAIBatchProcessor')
+    def test_batch_retrieve_with_error_handling(self, mock_processor_class):
+        """Test batch retrieval with error handling."""
+        mock_processor = MagicMock()
+        mock_processor.poll_batch_status.side_effect = BatchError("Batch polling failed")
+        mock_processor_class.return_value = mock_processor
+
+        with pytest.raises(BatchError, match="Batch polling failed"):
+            self.llm.batch_retrieve("batch_123")
+
+    def test_batch_method_input_conversion(self):
+        """Test batch method handles various input formats correctly."""
+        with patch.object(self.llm, 'batch_create') as mock_create, \
+             patch.object(self.llm, 'batch_retrieve') as mock_retrieve:
+
+            mock_create.return_value = "batch_123"
+            mock_retrieve.return_value = [
+                ChatResult(generations=[ChatGeneration(message=AIMessage(content="Response"))]),
+            ]
+
+            # Test with string inputs
+            inputs = ["Hello world"]
+            results = self.llm.batch(inputs, use_batch_api=True)
+
+            # Verify conversion happened
+            mock_create.assert_called_once()
+            call_args = mock_create.call_args[1]
+            messages_list = call_args['messages_list']
+
+            assert len(messages_list) == 1
+            assert len(messages_list[0]) == 1
+            assert isinstance(messages_list[0][0], HumanMessage)
+            assert messages_list[0][0].content == "Hello world"
+
+
+class TestBatchErrorHandling:
+    """Test error handling scenarios."""
+
+    def test_batch_error_creation(self):
+        """Test BatchError exception creation."""
+        error = BatchError("Test error message")
+        assert str(error) == "Test error message"
+
+    def test_batch_error_with_details(self):
+        """Test BatchError with additional details."""
+        details = {"batch_id": "batch_123", "status": "failed"}
+        error = BatchError("Batch failed", details)
+        assert str(error) == "Batch failed"
+        assert error.args[1] == details
+
+
+class TestBatchIntegrationScenarios:
+    """Test integration scenarios and edge cases."""
+
+    def setup_method(self):
+        """Set up test fixtures."""
+        self.llm = ChatOpenAI(model="gpt-3.5-turbo", api_key="test-key")
+
+    @patch('langchain_openai.chat_models.batch.OpenAIBatchProcessor')
+    def test_empty_messages_list(self, mock_processor_class):
+        """Test handling of empty messages list."""
+        mock_processor = MagicMock()
+        mock_processor.create_batch.return_value = "batch_123"
+        mock_processor.retrieve_batch_results.return_value = []
+        mock_processor_class.return_value = mock_processor
+
+        results = self.llm.batch([], use_batch_api=True)
+        assert results == []
+
+    @patch('langchain_openai.chat_models.batch.OpenAIBatchProcessor')
+    def test_large_batch_processing(self, mock_processor_class):
+        """Test processing of large batch."""
+        mock_processor = MagicMock()
+        mock_processor.create_batch.return_value = "batch_123"
+
+        # Create mock results for large batch
+        num_requests = 100
+        mock_chat_results = [
+            ChatResult(generations=[ChatGeneration(message=AIMessage(content=f"Response {i}"))])
+            for i in range(num_requests)
+        ]
+        mock_processor.retrieve_batch_results.return_value = mock_chat_results
+        mock_processor_class.return_value = mock_processor
+
+        inputs = [[HumanMessage(content=f"Question {i}")] for i in range(num_requests)]
+        results = self.llm.batch(inputs, use_batch_api=True)
+
+        assert len(results) == num_requests
+        for i, result in enumerate(results):
+            assert result.content == f"Response {i}"
+
+    @patch('langchain_openai.chat_models.batch.OpenAIBatchProcessor')
+    def test_mixed_message_types(self, mock_processor_class):
+        """Test batch processing with mixed message types."""
+        mock_processor = MagicMock()
+        mock_processor.create_batch.return_value = "batch_123"
+        mock_processor.retrieve_batch_results.return_value = [
+            ChatResult(generations=[ChatGeneration(message=AIMessage(content="Response 1"))]),
+            ChatResult(generations=[ChatGeneration(message=AIMessage(content="Response 2"))]),
+        ]
+        mock_processor_class.return_value = mock_processor
+
+        inputs = [
+            "String input",  # Will be converted to HumanMessage
list")], # Already formatted + ] + + results = self.llm.batch(inputs, use_batch_api=True) + assert len(results) == 2 + + # Verify the conversion happened correctly + mock_processor.create_batch.assert_called_once() + call_args = mock_processor.create_batch.call_args[1] + messages_list = call_args['messages_list'] + + # First input should be converted to HumanMessage + assert isinstance(messages_list[0][0], HumanMessage) + assert messages_list[0][0].content == "String input" + + # Second input should remain as is + assert isinstance(messages_list[1][0], HumanMessage) + assert messages_list[1][0].content == "Direct message list"