diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_batch_integration.py b/libs/partners/openai/tests/integration_tests/chat_models/test_batch_integration.py index 969f7ab935a..5601c523e2d 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_batch_integration.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_batch_integration.py @@ -302,8 +302,7 @@ class TestBatchAPIEdgeCases: messages_list = [[HumanMessage(content=long_content)]] batch_id = self.llm.batch_create( - messages_list=messages_list, - max_tokens=200, # Allow more tokens for summary + messages_list=messages_list, # Allow more tokens for summary ) results = self.llm.batch_retrieve(batch_id, timeout=1800.0) @@ -322,8 +321,7 @@ class TestBatchAPIPerformance: """Set up test fixtures.""" self.llm = ChatOpenAI( model="gpt-3.5-turbo", - temperature=0.1, - max_tokens=30, # Keep responses short + temperature=0.1, # Keep responses short ) @pytest.mark.scheduled diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_batch.py b/libs/partners/openai/tests/unit_tests/chat_models/test_batch.py index 9c2dedf20c6..2c84fdf87d7 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_batch.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_batch.py @@ -15,17 +15,13 @@ from langchain_openai.chat_models.batch import ( ) -class TestOpenAIBatchClient: +class TestOpenAIBatchProcessor: """Test the OpenAIBatchClient class.""" def setup_method(self) -> None: """Set up test fixtures.""" self.mock_client = MagicMock() - self.batch_client = OpenAIBatchClient( - client=self.mock_client, - poll_interval=0.1, # Fast polling for tests - timeout=5.0, - ) + self.batch_processor = OpenAIBatchProcessor(client=self.mock_client, model="gpt-3.5-turbo") def test_create_batch_success(self) -> None: """Test successful batch creation.""" @@ -47,8 +43,8 @@ class TestOpenAIBatchClient: } ] - batch_id = self.batch_client.create_batch( - requests=batch_requests, description="Test batch", metadata={"test": "true"} + batch_id = self.batch_processor.create_batch( + batch_requests=batch_requests, description="Test batch", metadata={"test": "true"} ) assert batch_id == "batch_123" @@ -71,7 +67,7 @@ class TestOpenAIBatchClient: ] with pytest.raises(BatchError, match="Failed to create batch"): - self.batch_client.create_batch(requests=batch_requests) + self.batch_processor.create_batch(batch_requests=batch_requests) def test_poll_batch_status_completed(self) -> None: """Test polling until batch completion.""" @@ -92,7 +88,7 @@ class TestOpenAIBatchClient: mock_batch_completed, ] - result = self.batch_client.poll_batch_status("batch_123") + result = self.batch_processor.poll_batch_status("batch_123") assert result.status == "completed" assert result.output_file_id == "file_123" @@ -107,7 +103,7 @@ class TestOpenAIBatchClient: self.mock_client.batches.retrieve.return_value = mock_batch_failed with pytest.raises(BatchError, match="Batch failed"): - self.batch_client.poll_batch_status("batch_123") + self.batch_processor.poll_batch_status("batch_123") def test_poll_batch_status_timeout(self) -> None: """Test polling timeout.""" @@ -119,7 +115,7 @@ class TestOpenAIBatchClient: # Set very short timeout self.batch_ with pytest.raises(BatchError, match="Batch polling timed out"): - self.batch_client.poll_batch_status("batch_123") + self.batch_processor.poll_batch_status("batch_123") def test_retrieve_batch_results_success(self) -> None: """Test successful batch result retrieval.""" @@ -152,7 +148,7 @@ class TestOpenAIBatchClient: mock_file_content = "\n".join(json.dumps(result) for result in mock_results) self.mock_client.files.content.return_value.content = mock_file_content.encode() - results = self.batch_client.retrieve_batch_results("file_123") + results = self.batch_processor.retrieve_batch_results("file_123") assert len(results) == 1 assert results[0]["custom_id"] == "request-1" @@ -163,7 +159,7 @@ class TestOpenAIBatchClient: self.mock_client.files.content.side_effect = Exception("File not found") with pytest.raises(BatchError, match="Failed to retrieve batch results"): - self.batch_client.retrieve_batch_results("file_123") + self.batch_processor.retrieve_batch_results("file_123") class TestOpenAIBatchProcessor: @@ -454,8 +450,7 @@ class TestBaseChatOpenAIBatchMethods: mock_processor.create_batch.assert_called_once() mock_processor.retrieve_batch_results.assert_called_once() - def test_batch_method_with_batch_api_false(self) -> None: - """Test batch method with use_batch_api=False (default behavior).""" + default behavior).""" inputs = [ [HumanMessage(content="Question 1")], [HumanMessage(content="Question 2")], @@ -517,37 +512,6 @@ class TestBaseChatOpenAIBatchMethods: with pytest.raises(BatchError, match="Batch polling failed"): self.llm.batch_retrieve("batch_123") - def test_batch_method_input_conversion(self) -> None: - """Test batch method handles various input formats correctly.""" - with ( - patch.object(self.llm, "batch_create") as mock_create, - patch.object(self.llm, "batch_retrieve") as mock_retrieve, - ): - mock_create.return_value = "batch_123" - mock_retrieve.return_value = [ - ChatResult( - generations=[ChatGeneration(message=AIMessage(content="Response"))] - ) - ] - - # Test with string inputs - inputs = ["Hello world"] - _ = self.llm.batch(inputs, use_batch_api=True) - - # Verify conversion happened - mock_create.assert_called_once() - call_args = mock_create.call_args[1] - messages_list = call_args["messages_list"] - - assert len(messages_list) == 1 - assert len(messages_list[0]) == 1 - assert isinstance(messages_list[0][0], HumanMessage) - assert messages_list[0][0].content == "Hello world" - - -class TestBatchErrorHandling: - """Test error handling scenarios.""" - def test_batch_error_creation(self) -> None: """Test BatchError exception creation.""" error = BatchError("Test error message")