Mirror of https://github.com/hwchase17/langchain.git (synced 2026-01-30 05:47:54 +00:00)

Commit: Apply patch [skip ci]
@@ -24,7 +24,7 @@ pytestmark = pytest.mark.skipif(
 class TestBatchAPIIntegration:
     """Integration tests for OpenAI Batch API functionality."""

-    def setup_method(self):
+    def setup_method(self) -> None:
         """Set up test fixtures."""
         self.llm = ChatOpenAI(
             model="gpt-3.5-turbo",
@@ -33,7 +33,7 @@ class TestBatchAPIIntegration:
         )

     @pytest.mark.scheduled
-    def test_batch_create_and_retrieve_small_batch(self):
+    def test_batch_create_and_retrieve_small_batch(self) -> None:
         """Test end-to-end batch processing with a small batch."""
         # Create a small batch of simple questions
         messages_list = [
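For orientation, the two-step flow these integration tests exercise looks roughly as follows. `batch_create` and `batch_retrieve` are introduced by this patch, so the sketch is an assumption based on how the tests call them, not an established langchain-openai API:

    from langchain_core.messages import HumanMessage
    from langchain_openai import ChatOpenAI

    llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.1)
    messages_list = [
        [HumanMessage(content="What is 2+2? Answer with just the number.")],
        [HumanMessage(content="What is the capital of France? One word.")],
    ]

    # Step 1: submit the requests; batch_create (added by this patch)
    # returns an OpenAI batch id.
    batch_id = llm.batch_create(messages_list)

    # Step 2: poll until the batch finishes, then collect one response per input.
    results = llm.batch_retrieve(batch_id, timeout=300.0)
    assert "paris" in results[1].content.lower()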
@@ -82,7 +82,7 @@ class TestBatchAPIIntegration:
         assert "paris" in response2.lower()

     @pytest.mark.scheduled
-    def test_batch_method_with_batch_api_true(self):
+    def test_batch_method_with_batch_api_true(self) -> None:
         """Test the batch() method with use_batch_api=True."""
         inputs = [
             [HumanMessage(content="Count to 3. Answer with just: 1, 2, 3")],
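The `use_batch_api` flag is the user-facing switch: `batch()` keeps its normal concurrent behavior by default and only goes through the Batch API when the flag is set. A minimal sketch, again assuming the keyword this patch adds:

    from langchain_core.messages import HumanMessage
    from langchain_openai import ChatOpenAI

    llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.1)
    inputs = [
        [HumanMessage(content="Count to 3. Answer with just: 1, 2, 3")],
        [HumanMessage(content="What color is the sky? Answer with one word.")],
    ]

    # use_batch_api=True (this patch) routes through OpenAI's Batch API;
    # the default False keeps the standard parallel invoke path.
    responses = llm.batch(inputs, use_batch_api=True)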
@@ -107,7 +107,7 @@ class TestBatchAPIIntegration:
         assert "blue" in response2

     @pytest.mark.scheduled
-    def test_batch_method_comparison(self):
+    def test_batch_method_comparison(self) -> None:
         """Test that batch API and standard batch produce similar results."""
         inputs = [[HumanMessage(content="What is 1+1? Answer with just the number.")]]

@@ -132,7 +132,7 @@ class TestBatchAPIIntegration:
         assert "2" in batch_content or "two" in batch_content.lower()

     @pytest.mark.scheduled
-    def test_batch_with_different_parameters(self):
+    def test_batch_with_different_parameters(self) -> None:
         """Test batch processing with different model parameters."""
         messages_list = [
             [HumanMessage(content="Write a haiku about coding. Keep it short.")]
@@ -160,7 +160,7 @@ class TestBatchAPIIntegration:
         assert "\n" in result_content or len(result_content.split()) >= 5

     @pytest.mark.scheduled
-    def test_batch_with_system_message(self):
+    def test_batch_with_system_message(self) -> None:
         """Test batch processing with system messages."""
         from langchain_core.messages import SystemMessage

@@ -188,7 +188,7 @@ class TestBatchAPIIntegration:
         assert "30" in result_content or "thirty" in result_content.lower()

     @pytest.mark.scheduled
-    def test_batch_error_handling_invalid_model(self):
+    def test_batch_error_handling_invalid_model(self) -> None:
         """Test error handling with invalid model parameters."""
         # Create a ChatOpenAI instance with an invalid model
         invalid_llm = ChatOpenAI(model="invalid-model-name-12345", temperature=0.1)
@@ -201,7 +201,7 @@ class TestBatchAPIIntegration:
             # If batch creation succeeds, retrieval should fail
             invalid_llm.batch_retrieve(batch_id, timeout=300.0)

-    def test_batch_input_conversion(self):
+    def test_batch_input_conversion(self) -> None:
         """Test batch processing with various input formats."""
         # Test with string inputs (should be converted to HumanMessage)
         inputs = [
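The invalid-model test documents the expected failure mode: creation may be accepted by the API, but the error must surface as a `BatchError` no later than retrieval, so callers handle one exception type. A hedged sketch of that contract:

    import pytest
    from langchain_core.messages import HumanMessage
    from langchain_openai import ChatOpenAI
    from langchain_openai.chat_models.batch import BatchError  # added by this patch

    invalid_llm = ChatOpenAI(model="invalid-model-name-12345", temperature=0.1)
    with pytest.raises(BatchError):
        batch_id = invalid_llm.batch_create([[HumanMessage(content="Hello")]])
        # If creation is accepted anyway, retrieval should raise instead.
        invalid_llm.batch_retrieve(batch_id, timeout=300.0)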
@@ -230,7 +230,7 @@ class TestBatchAPIIntegration:
         assert "mercury" in response2

     @pytest.mark.scheduled
-    def test_empty_batch_handling(self):
+    def test_empty_batch_handling(self) -> None:
         """Test handling of empty batch inputs."""
         # Empty inputs should return empty results
         results = self.llm.batch([], use_batch_api=True)
@@ -242,7 +242,7 @@ class TestBatchAPIIntegration:
         assert results == []

     @pytest.mark.scheduled
-    def test_batch_metadata_preservation(self):
+    def test_batch_metadata_preservation(self) -> None:
         """Test that batch metadata is properly handled."""
         messages_list = [[HumanMessage(content="Say 'test successful'")]]

@@ -270,12 +270,12 @@ class TestBatchAPIIntegration:
 class TestBatchAPIEdgeCases:
     """Test edge cases and error scenarios."""

-    def setup_method(self):
+    def setup_method(self) -> None:
         """Set up test fixtures."""
         self.llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.1, max_tokens=50)

     @pytest.mark.scheduled
-    def test_batch_with_very_short_timeout(self):
+    def test_batch_with_very_short_timeout(self) -> None:
         """Test batch processing with very short timeout."""
         messages_list = [[HumanMessage(content="Hello")]]

@@ -289,13 +289,13 @@ class TestBatchAPIEdgeCases:
                 timeout=5.0,  # Very short timeout
             )

-    def test_batch_retrieve_invalid_batch_id(self):
+    def test_batch_retrieve_invalid_batch_id(self) -> None:
         """Test retrieving results with invalid batch ID."""
         with pytest.raises(BatchError):
             self.llm.batch_retrieve("invalid_batch_id_12345", timeout=30.0)

     @pytest.mark.scheduled
-    def test_batch_with_long_content(self):
+    def test_batch_with_long_content(self) -> None:
         """Test batch processing with longer content."""
         long_content = "Please summarize this text: " + "This is a test sentence. " * 20

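Bogus batch ids follow the same contract: the wrapper is expected to translate the OpenAI-side lookup failure into a `BatchError` rather than leak the raw SDK exception. The calling pattern, with `BatchError` imported from the module this patch adds:

    import pytest
    from langchain_openai import ChatOpenAI
    from langchain_openai.chat_models.batch import BatchError  # added by this patch

    llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.1, max_tokens=50)
    with pytest.raises(BatchError):
        llm.batch_retrieve("invalid_batch_id_12345", timeout=30.0)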
@@ -318,7 +318,7 @@ class TestBatchAPIEdgeCases:
 class TestBatchAPIPerformance:
     """Performance and scalability tests."""

-    def setup_method(self):
+    def setup_method(self) -> None:
         """Set up test fixtures."""
         self.llm = ChatOpenAI(
             model="gpt-3.5-turbo",
@@ -327,7 +327,7 @@ class TestBatchAPIPerformance:
         )

     @pytest.mark.scheduled
-    def test_medium_batch_processing(self):
+    def test_medium_batch_processing(self) -> None:
         """Test processing a medium-sized batch (10 requests)."""
         # Create 10 simple math questions
         messages_list = [
@@ -364,7 +364,7 @@ class TestBatchAPIPerformance:

         # Log processing time for analysis

     @pytest.mark.scheduled
-    def test_batch_vs_sequential_comparison(self):
+    def test_batch_vs_sequential_comparison(self) -> None:
         """Compare batch API performance vs sequential processing."""
         messages = [
             [HumanMessage(content="Count to 2. Answer: 1, 2")],
@@ -18,7 +18,7 @@ from langchain_openai.chat_models.batch import (
 class TestOpenAIBatchClient:
     """Test the OpenAIBatchClient class."""

-    def setup_method(self):
+    def setup_method(self) -> None:
         """Set up test fixtures."""
         self.mock_client = MagicMock()
         self.batch_client = OpenAIBatchClient(
@@ -27,7 +27,7 @@ class TestOpenAIBatchClient:
             timeout=5.0,
         )

-    def test_create_batch_success(self):
+    def test_create_batch_success(self) -> None:
         """Test successful batch creation."""
         # Mock batch creation response
         mock_batch = MagicMock()
@@ -56,7 +56,7 @@ class TestOpenAIBatchClient:
         assert batch_id == "batch_123"
         self.mock_client.batches.create.assert_called_once()

-    def test_create_batch_failure(self):
+    def test_create_batch_failure(self) -> None:
         """Test batch creation failure."""
         self.mock_client.batches.create.side_effect = Exception("API Error")

@@ -75,7 +75,7 @@ class TestOpenAIBatchClient:
         with pytest.raises(BatchError, match="Failed to create batch"):
             self.batch_client.create_batch(batch_requests=batch_requests)

-    def test_poll_batch_status_completed(self):
+    def test_poll_batch_status_completed(self) -> None:
         """Test polling until batch completion."""
         # Mock batch status progression
         mock_batch_validating = MagicMock()
@@ -100,7 +100,7 @@ class TestOpenAIBatchClient:
         assert result.output_file_id == "file_123"
         assert self.mock_client.batches.retrieve.call_count == 3

-    def test_poll_batch_status_failed(self):
+    def test_poll_batch_status_failed(self) -> None:
         """Test polling when batch fails."""
         mock_batch_failed = MagicMock()
         mock_batch_failed.status = "failed"
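The polling tests drive `poll_batch_status` entirely through `unittest.mock`: `batches.retrieve` is given a side-effect list so each poll observes the next status, which is what the `call_count == 3` assertion above pins down. The stubbing pattern (standard mock API, nothing patch-specific):

    from unittest.mock import MagicMock

    mock_client = MagicMock()
    validating, in_progress, completed = MagicMock(), MagicMock(), MagicMock()
    validating.status = "validating"
    in_progress.status = "in_progress"
    completed.status = "completed"
    completed.output_file_id = "file_123"

    # Poll 1 sees "validating", poll 2 "in_progress", poll 3 "completed".
    mock_client.batches.retrieve.side_effect = [validating, in_progress, completed]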
@@ -111,7 +111,7 @@ class TestOpenAIBatchClient:
         with pytest.raises(BatchError, match="Batch failed"):
             self.batch_client.poll_batch_status("batch_123")

-    def test_poll_batch_status_timeout(self):
+    def test_poll_batch_status_timeout(self) -> None:
         """Test polling timeout."""
         mock_batch_in_progress = MagicMock()
         mock_batch_in_progress.status = "in_progress"
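Taken together, the completed/failed/timeout cases constrain `poll_batch_status` to something like the loop below. This is a reconstruction from the asserted match strings ("Batch failed", "Batch polling timed out"), not the patch's actual implementation:

    import time

    def poll_batch_status(client, batch_id, poll_interval=1.0, timeout=5.0):
        # Hypothetical loop; BatchError is the exception type this patch defines.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            batch = client.batches.retrieve(batch_id)
            if batch.status == "completed":
                return batch
            if batch.status in ("failed", "expired", "cancelled"):
                raise BatchError(f"Batch failed with status: {batch.status}")
            time.sleep(poll_interval)
        raise BatchError("Batch polling timed out")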
@@ -124,7 +124,7 @@ class TestOpenAIBatchClient:
         with pytest.raises(BatchError, match="Batch polling timed out"):
             self.batch_client.poll_batch_status("batch_123")

-    def test_retrieve_batch_results_success(self):
+    def test_retrieve_batch_results_success(self) -> None:
         """Test successful batch result retrieval."""
         # Mock file content
         mock_results = [
@@ -161,7 +161,7 @@ class TestOpenAIBatchClient:
         assert results[0]["custom_id"] == "request-1"
         assert results[0]["response"]["status_code"] == 200

-    def test_retrieve_batch_results_failure(self):
+    def test_retrieve_batch_results_failure(self) -> None:
         """Test batch result retrieval failure."""
         self.mock_client.files.content.side_effect = Exception("File not found")

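The mocked file content mirrors OpenAI's documented Batch output format: the file fetched via `client.files.content(output_file_id)` is JSONL, one record per request, keyed by the request's `custom_id`. The assertions above read exactly these fields (chat body abbreviated here):

    import json

    output_line = json.dumps({
        "id": "batch_req_abc123",
        "custom_id": "request-1",
        "response": {
            "status_code": 200,
            "body": {
                "choices": [
                    {"message": {"role": "assistant", "content": "4"}}
                ]
            },
        },
        "error": None,
    })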
@@ -172,7 +172,7 @@ class TestOpenAIBatchClient:
 class TestOpenAIBatchProcessor:
     """Test the OpenAIBatchProcessor class."""

-    def setup_method(self):
+    def setup_method(self) -> None:
         """Set up test fixtures."""
         self.mock_client = MagicMock()
         self.processor = OpenAIBatchProcessor(
@@ -182,7 +182,7 @@ class TestOpenAIBatchProcessor:
             timeout=5.0,
         )

-    def test_create_batch_success(self):
+    def test_create_batch_success(self) -> None:
         """Test successful batch creation with message conversion."""
         # Mock batch client
         with patch.object(self.processor, "batch_client") as mock_batch_client:
@@ -214,7 +214,7 @@ class TestOpenAIBatchProcessor:
         assert batch_requests[0]["body"]["messages"][0]["role"] == "user"
         assert batch_requests[0]["body"]["messages"][0]["content"] == "What is 2+2?"

-    def test_poll_batch_status_success(self):
+    def test_poll_batch_status_success(self) -> None:
         """Test successful batch status polling."""
         with patch.object(self.processor, "batch_client") as mock_batch_client:
             mock_batch = MagicMock()
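On the input side, each element of `batch_requests` is one line of the Batch input file in OpenAI's documented format; the role/content assertions above check the LangChain-message-to-dict conversion inside `body.messages`:

    batch_request = {
        "custom_id": "request-0",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": "gpt-3.5-turbo",
            "messages": [{"role": "user", "content": "What is 2+2?"}],
        },
    }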
@@ -228,7 +228,7 @@ class TestOpenAIBatchProcessor:
             "batch_123", poll_interval=None, timeout=None
         )

-    def test_retrieve_batch_results_success(self):
+    def test_retrieve_batch_results_success(self) -> None:
         """Test successful batch result retrieval and conversion."""
         # Mock batch status and results
         mock_batch = MagicMock()
@@ -305,7 +305,7 @@ class TestOpenAIBatchProcessor:
             == "The capital of France is Paris."
         )

-    def test_retrieve_batch_results_with_errors(self):
+    def test_retrieve_batch_results_with_errors(self) -> None:
         """Test batch result retrieval with some failed requests."""
         mock_batch = MagicMock()
         mock_batch.status = "completed"
@@ -360,12 +360,12 @@ class TestOpenAIBatchProcessor:
 class TestBaseChatOpenAIBatchMethods:
     """Test the batch methods added to BaseChatOpenAI."""

-    def setup_method(self):
+    def setup_method(self) -> None:
         """Set up test fixtures."""
         self.llm = ChatOpenAI(model="gpt-3.5-turbo", api_key="test-key")

     @patch("langchain_openai.chat_models.batch.OpenAIBatchProcessor")
-    def test_batch_create_success(self, mock_processor_class):
+    def test_batch_create_success(self, mock_processor_class) -> None:
         """Test successful batch creation."""
         mock_processor = MagicMock()
         mock_processor.create_batch.return_value = "batch_123"
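These unit tests patch `OpenAIBatchProcessor` at the name the chat model looks it up under, so nothing reaches the network. The pattern, with `batch_create` being the method this patch adds (its call signature here is an assumption taken from the tests):

    from unittest.mock import MagicMock, patch

    from langchain_core.messages import HumanMessage
    from langchain_openai import ChatOpenAI

    with patch("langchain_openai.chat_models.batch.OpenAIBatchProcessor") as processor_cls:
        processor = MagicMock()
        processor.create_batch.return_value = "batch_123"
        processor_cls.return_value = processor

        llm = ChatOpenAI(model="gpt-3.5-turbo", api_key="test-key")
        batch_id = llm.batch_create([[HumanMessage(content="Hi")]])  # assumed API
        assert batch_id == "batch_123"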
@@ -393,7 +393,7 @@ class TestBaseChatOpenAIBatchMethods:
         )

     @patch("langchain_openai.chat_models.batch.OpenAIBatchProcessor")
-    def test_batch_retrieve_success(self, mock_processor_class):
+    def test_batch_retrieve_success(self, mock_processor_class) -> None:
         """Test successful batch result retrieval."""
         mock_processor = MagicMock()
         mock_chat_results = [
@@ -426,7 +426,7 @@ class TestBaseChatOpenAIBatchMethods:
         mock_processor.retrieve_batch_results.assert_called_once_with("batch_123")

     @patch("langchain_openai.chat_models.batch.OpenAIBatchProcessor")
-    def test_batch_method_with_batch_api_true(self, mock_processor_class):
+    def test_batch_method_with_batch_api_true(self, mock_processor_class) -> None:
         """Test batch method with use_batch_api=True."""
         mock_processor = MagicMock()
         mock_chat_results = [
@@ -457,7 +457,7 @@ class TestBaseChatOpenAIBatchMethods:
         mock_processor.create_batch.assert_called_once()
         mock_processor.retrieve_batch_results.assert_called_once()

-    def test_batch_method_with_batch_api_false(self):
+    def test_batch_method_with_batch_api_false(self) -> None:
         """Test batch method with use_batch_api=False (default behavior)."""
         inputs = [
             [HumanMessage(content="Question 1")],
@@ -478,7 +478,7 @@ class TestBaseChatOpenAIBatchMethods:
             inputs=inputs, config=None, return_exceptions=False
         )

-    def test_convert_input_to_messages_list(self):
+    def test_convert_input_to_messages_list(self) -> None:
         """Test _convert_input_to_messages helper method."""
         # Test list of messages
         messages = [HumanMessage(content="Hello")]
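`_convert_input_to_messages` is private to this patch, but the surrounding tests suggest it normalizes a bare string, a single message, or a message list into a list of messages. A hypothetical stand-in with that behavior:

    from langchain_core.messages import BaseMessage, HumanMessage

    def convert_input_to_messages(value):
        # Sketch only; the real helper lives on BaseChatOpenAI in this patch.
        if isinstance(value, str):
            return [HumanMessage(content=value)]
        if isinstance(value, BaseMessage):
            return [value]
        return list(value)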
@@ -497,7 +497,7 @@ class TestBaseChatOpenAIBatchMethods:
         assert result[0].content == "Hello"

     @patch("langchain_openai.chat_models.batch.OpenAIBatchProcessor")
-    def test_batch_create_with_error_handling(self, mock_processor_class):
+    def test_batch_create_with_error_handling(self, mock_processor_class) -> None:
         """Test batch creation with error handling."""
         mock_processor = MagicMock()
         mock_processor.create_batch.side_effect = BatchError("Batch creation failed")
@@ -509,7 +509,7 @@ class TestBaseChatOpenAIBatchMethods:
             self.llm.batch_create(messages_list)

     @patch("langchain_openai.chat_models.batch.OpenAIBatchProcessor")
-    def test_batch_retrieve_with_error_handling(self, mock_processor_class):
+    def test_batch_retrieve_with_error_handling(self, mock_processor_class) -> None:
         """Test batch retrieval with error handling."""
         mock_processor = MagicMock()
         mock_processor.poll_batch_status.side_effect = BatchError(
@@ -520,7 +520,7 @@ class TestBaseChatOpenAIBatchMethods:
         with pytest.raises(BatchError, match="Batch polling failed"):
             self.llm.batch_retrieve("batch_123")

-    def test_batch_method_input_conversion(self):
+    def test_batch_method_input_conversion(self) -> None:
         """Test batch method handles various input formats correctly."""
         with (
             patch.object(self.llm, "batch_create") as mock_create,
@@ -551,12 +551,12 @@ class TestBaseChatOpenAIBatchMethods:
 class TestBatchErrorHandling:
     """Test error handling scenarios."""

-    def test_batch_error_creation(self):
+    def test_batch_error_creation(self) -> None:
         """Test BatchError exception creation."""
         error = BatchError("Test error message")
         assert str(error) == "Test error message"

-    def test_batch_error_with_details(self):
+    def test_batch_error_with_details(self) -> None:
         """Test BatchError with additional details."""
         details = {"batch_id": "batch_123", "status": "failed"}
         error = BatchError("Batch failed", details)
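The two error tests fix `BatchError`'s observable behavior: `str(error)` is the message, and an optional details mapping rides along. One definition consistent with those assertions (the patch's own source may differ):

    class BatchError(Exception):
        """Sketch reconstructed from the tests; not the patch's source."""

        def __init__(self, message, details=None):
            super().__init__(message)
            self.details = details or {}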
@@ -567,12 +567,12 @@ class TestBatchErrorHandling:
 class TestBatchIntegrationScenarios:
     """Test integration scenarios and edge cases."""

-    def setup_method(self):
+    def setup_method(self) -> None:
         """Set up test fixtures."""
         self.llm = ChatOpenAI(model="gpt-3.5-turbo", api_key="test-key")

     @patch("langchain_openai.chat_models.batch.OpenAIBatchProcessor")
-    def test_empty_messages_list(self, mock_processor_class):
+    def test_empty_messages_list(self, mock_processor_class) -> None:
         """Test handling of empty messages list."""
         mock_processor = MagicMock()
         mock_processor.create_batch.return_value = "batch_123"
@@ -583,7 +583,7 @@ class TestBatchIntegrationScenarios:
         assert results == []

     @patch("langchain_openai.chat_models.batch.OpenAIBatchProcessor")
-    def test_large_batch_processing(self, mock_processor_class):
+    def test_large_batch_processing(self, mock_processor_class) -> None:
         """Test processing of large batch."""
         mock_processor = MagicMock()
         mock_processor.create_batch.return_value = "batch_123"
@@ -607,7 +607,7 @@ class TestBatchIntegrationScenarios:
         assert result.content == f"Response {i}"

     @patch("langchain_openai.chat_models.batch.OpenAIBatchProcessor")
-    def test_mixed_message_types(self, mock_processor_class):
+    def test_mixed_message_types(self, mock_processor_class) -> None:
         """Test batch processing with mixed message types."""
         mock_processor = MagicMock()
         mock_processor.create_batch.return_value = "batch_123"
