mirror of
https://github.com/hwchase17/langchain.git
synced 2026-01-30 05:47:54 +00:00
Apply patch [skip ci]
This commit is contained in:
@@ -302,8 +302,7 @@ class TestBatchAPIEdgeCases:
|
||||
messages_list = [[HumanMessage(content=long_content)]]
|
||||
|
||||
batch_id = self.llm.batch_create(
|
||||
messages_list=messages_list,
|
||||
max_tokens=200, # Allow more tokens for summary
|
||||
messages_list=messages_list, # Allow more tokens for summary
|
||||
)
|
||||
|
||||
results = self.llm.batch_retrieve(batch_id, timeout=1800.0)
|
||||
@@ -322,8 +321,7 @@ class TestBatchAPIPerformance:
|
||||
"""Set up test fixtures."""
|
||||
self.llm = ChatOpenAI(
|
||||
model="gpt-3.5-turbo",
|
||||
temperature=0.1,
|
||||
max_tokens=30, # Keep responses short
|
||||
temperature=0.1, # Keep responses short
|
||||
)
|
||||
|
||||
@pytest.mark.scheduled
|
||||
|
||||
@@ -15,17 +15,13 @@ from langchain_openai.chat_models.batch import (
|
||||
)
|
||||
|
||||
|
||||
class TestOpenAIBatchClient:
|
||||
class TestOpenAIBatchProcessor:
|
||||
"""Test the OpenAIBatchClient class."""
|
||||
|
||||
def setup_method(self) -> None:
|
||||
"""Set up test fixtures."""
|
||||
self.mock_client = MagicMock()
|
||||
self.batch_client = OpenAIBatchClient(
|
||||
client=self.mock_client,
|
||||
poll_interval=0.1, # Fast polling for tests
|
||||
timeout=5.0,
|
||||
)
|
||||
self.batch_processor = OpenAIBatchProcessor(client=self.mock_client, model="gpt-3.5-turbo")
|
||||
|
||||
def test_create_batch_success(self) -> None:
|
||||
"""Test successful batch creation."""
|
||||
@@ -47,8 +43,8 @@ class TestOpenAIBatchClient:
|
||||
}
|
||||
]
|
||||
|
||||
batch_id = self.batch_client.create_batch(
|
||||
requests=batch_requests, description="Test batch", metadata={"test": "true"}
|
||||
batch_id = self.batch_processor.create_batch(
|
||||
batch_requests=batch_requests, description="Test batch", metadata={"test": "true"}
|
||||
)
|
||||
|
||||
assert batch_id == "batch_123"
|
||||
@@ -71,7 +67,7 @@ class TestOpenAIBatchClient:
|
||||
]
|
||||
|
||||
with pytest.raises(BatchError, match="Failed to create batch"):
|
||||
self.batch_client.create_batch(requests=batch_requests)
|
||||
self.batch_processor.create_batch(batch_requests=batch_requests)
|
||||
|
||||
def test_poll_batch_status_completed(self) -> None:
|
||||
"""Test polling until batch completion."""
|
||||
@@ -92,7 +88,7 @@ class TestOpenAIBatchClient:
|
||||
mock_batch_completed,
|
||||
]
|
||||
|
||||
result = self.batch_client.poll_batch_status("batch_123")
|
||||
result = self.batch_processor.poll_batch_status("batch_123")
|
||||
|
||||
assert result.status == "completed"
|
||||
assert result.output_file_id == "file_123"
|
||||
@@ -107,7 +103,7 @@ class TestOpenAIBatchClient:
|
||||
self.mock_client.batches.retrieve.return_value = mock_batch_failed
|
||||
|
||||
with pytest.raises(BatchError, match="Batch failed"):
|
||||
self.batch_client.poll_batch_status("batch_123")
|
||||
self.batch_processor.poll_batch_status("batch_123")
|
||||
|
||||
def test_poll_batch_status_timeout(self) -> None:
|
||||
"""Test polling timeout."""
|
||||
@@ -119,7 +115,7 @@ class TestOpenAIBatchClient:
|
||||
# Set very short timeout
|
||||
self.batch_
|
||||
with pytest.raises(BatchError, match="Batch polling timed out"):
|
||||
self.batch_client.poll_batch_status("batch_123")
|
||||
self.batch_processor.poll_batch_status("batch_123")
|
||||
|
||||
def test_retrieve_batch_results_success(self) -> None:
|
||||
"""Test successful batch result retrieval."""
|
||||
@@ -152,7 +148,7 @@ class TestOpenAIBatchClient:
|
||||
mock_file_content = "\n".join(json.dumps(result) for result in mock_results)
|
||||
self.mock_client.files.content.return_value.content = mock_file_content.encode()
|
||||
|
||||
results = self.batch_client.retrieve_batch_results("file_123")
|
||||
results = self.batch_processor.retrieve_batch_results("file_123")
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0]["custom_id"] == "request-1"
|
||||
@@ -163,7 +159,7 @@ class TestOpenAIBatchClient:
|
||||
self.mock_client.files.content.side_effect = Exception("File not found")
|
||||
|
||||
with pytest.raises(BatchError, match="Failed to retrieve batch results"):
|
||||
self.batch_client.retrieve_batch_results("file_123")
|
||||
self.batch_processor.retrieve_batch_results("file_123")
|
||||
|
||||
|
||||
class TestOpenAIBatchProcessor:
|
||||
@@ -454,8 +450,7 @@ class TestBaseChatOpenAIBatchMethods:
|
||||
mock_processor.create_batch.assert_called_once()
|
||||
mock_processor.retrieve_batch_results.assert_called_once()
|
||||
|
||||
def test_batch_method_with_batch_api_false(self) -> None:
|
||||
"""Test batch method with use_batch_api=False (default behavior)."""
|
||||
default behavior)."""
|
||||
inputs = [
|
||||
[HumanMessage(content="Question 1")],
|
||||
[HumanMessage(content="Question 2")],
|
||||
@@ -517,37 +512,6 @@ class TestBaseChatOpenAIBatchMethods:
|
||||
with pytest.raises(BatchError, match="Batch polling failed"):
|
||||
self.llm.batch_retrieve("batch_123")
|
||||
|
||||
def test_batch_method_input_conversion(self) -> None:
|
||||
"""Test batch method handles various input formats correctly."""
|
||||
with (
|
||||
patch.object(self.llm, "batch_create") as mock_create,
|
||||
patch.object(self.llm, "batch_retrieve") as mock_retrieve,
|
||||
):
|
||||
mock_create.return_value = "batch_123"
|
||||
mock_retrieve.return_value = [
|
||||
ChatResult(
|
||||
generations=[ChatGeneration(message=AIMessage(content="Response"))]
|
||||
)
|
||||
]
|
||||
|
||||
# Test with string inputs
|
||||
inputs = ["Hello world"]
|
||||
_ = self.llm.batch(inputs, use_batch_api=True)
|
||||
|
||||
# Verify conversion happened
|
||||
mock_create.assert_called_once()
|
||||
call_args = mock_create.call_args[1]
|
||||
messages_list = call_args["messages_list"]
|
||||
|
||||
assert len(messages_list) == 1
|
||||
assert len(messages_list[0]) == 1
|
||||
assert isinstance(messages_list[0][0], HumanMessage)
|
||||
assert messages_list[0][0].content == "Hello world"
|
||||
|
||||
|
||||
class TestBatchErrorHandling:
|
||||
"""Test error handling scenarios."""
|
||||
|
||||
def test_batch_error_creation(self) -> None:
|
||||
"""Test BatchError exception creation."""
|
||||
error = BatchError("Test error message")
|
||||
|
||||
Reference in New Issue
Block a user