Apply patch [skip ci]

Author: open-swe[bot]
Date: 2025-08-11 20:33:56 +00:00
parent cc28873253
commit c5c43e3ced
6 changed files with 466 additions and 441 deletions

View File

@@ -1,4 +1,9 @@
from langchain_openai.chat_models import AzureChatOpenAI, ChatOpenAI, BatchError, BatchStatus
from langchain_openai.chat_models import (
AzureChatOpenAI,
BatchError,
BatchStatus,
ChatOpenAI,
)
from langchain_openai.embeddings import AzureOpenAIEmbeddings, OpenAIEmbeddings
from langchain_openai.llms import AzureOpenAI, OpenAI
from langchain_openai.tools import custom_tool
@@ -14,4 +19,3 @@ __all__ = [
"BatchError",
"BatchStatus",
]
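With ``BatchError`` and ``BatchStatus`` added to ``__all__``, the batch helpers are importable from the package root. A quick illustrative check (assumes this patch is installed):

from langchain_openai import BatchError, BatchStatus, ChatOpenAI

# BatchStatus is a str-backed Enum, so members compare equal to their raw status strings
assert BatchStatus.IN_PROGRESS == "in_progress"
assert issubclass(BatchError, Exception)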

View File

@@ -3,4 +3,3 @@ from langchain_openai.chat_models.base import ChatOpenAI
from langchain_openai.chat_models.batch import BatchError, BatchStatus
__all__ = ["ChatOpenAI", "AzureChatOpenAI", "BatchError", "BatchStatus"]

View File

@@ -1503,142 +1503,142 @@ class BaseChatOpenAI(BaseChatModel):
) -> int:
"""Calculate num tokens for ``gpt-3.5-turbo`` and ``gpt-4`` with ``tiktoken`` package.
**Requirements**: You must have the ``pillow`` installed if you want to count
image tokens if you are specifying the image as a base64 string, and you must
have both ``pillow`` and ``httpx`` installed if you are specifying the image
as a URL. If these aren't installed image inputs will be ignored in token
counting.
**Requirements**: You must have the ``pillow`` installed if you want to count
image tokens if you are specifying the image as a base64 string, and you must
have both ``pillow`` and ``httpx`` installed if you are specifying the image
as a URL. If these aren't installed image inputs will be ignored in token
counting.
`OpenAI reference <https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb>`__
`OpenAI reference <https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb>`__
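Example (an illustrative sketch; assumes this method is ``ChatOpenAI.get_num_tokens_from_messages``, that a vision-capable model is used, and that the image URL is a placeholder):
.. code-block:: python
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI
llm = ChatOpenAI(model="gpt-4o")
messages = [
HumanMessage(
content=[
{"type": "text", "text": "Describe this image."},
{"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
]
)
]
# Counting image tokens for the URL form requires both pillow and httpx.
num_tokens = llm.get_num_tokens_from_messages(messages)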
Args:
messages: The message inputs to tokenize.
tools: If provided, sequence of dict, BaseModel, function, or BaseTools
to be converted to tool schemas.
.. dropdown:: Batch API for cost savings
Args:
messages: The message inputs to tokenize.
tools: If provided, sequence of dict, BaseModel, function, or BaseTools
to be converted to tool schemas.
.. dropdown:: Batch API for cost savings
.. versionadded:: 0.3.7
.. versionadded:: 0.3.7
OpenAI's Batch API provides **50% cost savings** for non-real-time workloads by
processing requests asynchronously. This is ideal for tasks like data processing,
content generation, or evaluation that don't require immediate responses.
OpenAI's Batch API provides **50% cost savings** for non-real-time workloads by
processing requests asynchronously. This is ideal for tasks like data processing,
content generation, or evaluation that don't require immediate responses.
**Cost vs Latency Tradeoff:**
**Cost vs Latency Tradeoff:**
- **Standard API**: Immediate results, full pricing
- **Batch API**: 50% cost savings, asynchronous processing (results available within 24 hours)
- **Standard API**: Immediate results, full pricing
- **Batch API**: 50% cost savings, asynchronous processing (results available within 24 hours)
**Method 1: Direct batch management**
**Method 1: Direct batch management**
Use ``batch_create()`` and ``batch_retrieve()`` for full control over batch lifecycle:
Use ``batch_create()`` and ``batch_retrieve()`` for full control over batch lifecycle:
.. code-block:: python
.. code-block:: python
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage
llm = ChatOpenAI(model="gpt-3.5-turbo")
llm = ChatOpenAI(model="gpt-3.5-turbo")
# Prepare multiple message sequences for batch processing
messages_list = [
[HumanMessage(content="Translate 'hello' to French")],
[HumanMessage(content="Translate 'goodbye' to Spanish")],
[HumanMessage(content="What is the capital of Italy?")],
]
# Prepare multiple message sequences for batch processing
messages_list = [
[HumanMessage(content="Translate 'hello' to French")],
[HumanMessage(content="Translate 'goodbye' to Spanish")],
[HumanMessage(content="What is the capital of Italy?")],
]
# Create batch job (returns immediately with batch ID)
batch_id = llm.batch_create(
messages_list=messages_list,
description="Translation and geography batch",
metadata={"project": "multilingual_qa", "user": "analyst_1"},
)
print(f"Batch created: {batch_id}")
# Create batch job (returns immediately with batch ID)
batch_id = llm.batch_create(
messages_list=messages_list,
description="Translation and geography batch",
metadata={"project": "multilingual_qa", "user": "analyst_1"},
)
print(f"Batch created: {batch_id}")
# Later, retrieve results (polls until completion)
results = llm.batch_retrieve(
batch_id=batch_id,
poll_interval=60.0, # Check every minute
timeout=3600.0, # 1 hour timeout
)
# Later, retrieve results (polls until completion)
results = llm.batch_retrieve(
batch_id=batch_id,
poll_interval=60.0, # Check every minute
timeout=3600.0, # 1 hour timeout
)
# Process results
for i, result in enumerate(results):
response = result.generations[0].message.content
print(f"Response {i+1}: {response}")
# Process results
for i, result in enumerate(results):
response = result.generations[0].message.content
print(f"Response {i + 1}: {response}")
**Method 2: Enhanced batch() method**
**Method 2: Enhanced batch() method**
Use the familiar ``batch()`` method with ``use_batch_api=True`` for seamless integration:
Use the familiar ``batch()`` method with ``use_batch_api=True`` for seamless integration:
.. code-block:: python
.. code-block:: python
# Standard batch processing (immediate, full cost)
inputs = [
[HumanMessage(content="What is 2+2?")],
[HumanMessage(content="What is 3+3?")],
]
standard_results = llm.batch(inputs) # Default: use_batch_api=False
# Standard batch processing (immediate, full cost)
inputs = [
[HumanMessage(content="What is 2+2?")],
[HumanMessage(content="What is 3+3?")],
]
standard_results = llm.batch(inputs) # Default: use_batch_api=False
# Batch API processing (50% cost savings, polling)
batch_results = llm.batch(
inputs,
use_batch_api=True, # Enable cost savings
poll_interval=30.0, # Poll every 30 seconds
timeout=1800.0, # 30 minute timeout
)
# Batch API processing (50% cost savings, polling)
batch_results = llm.batch(
inputs,
use_batch_api=True, # Enable cost savings
poll_interval=30.0, # Poll every 30 seconds
timeout=1800.0, # 30 minute timeout
)
**Batch creation with custom parameters:**
**Batch creation with custom parameters:**
.. code-block:: python
.. code-block:: python
# Create batch with specific model parameters
batch_id = llm.batch_create(
messages_list=messages_list,
description="Creative writing batch",
metadata={"task_type": "content_generation"},
temperature=0.8, # Higher creativity
max_tokens=200, # Longer responses
top_p=0.9, # Nucleus sampling
)
# Create batch with specific model parameters
batch_id = llm.batch_create(
messages_list=messages_list,
description="Creative writing batch",
metadata={"task_type": "content_generation"},
temperature=0.8, # Higher creativity
max_tokens=200, # Longer responses
top_p=0.9, # Nucleus sampling
)
**Error handling and monitoring:**
**Error handling and monitoring:**
.. code-block:: python
.. code-block:: python
from langchain_openai.chat_models.batch import BatchError
from langchain_openai.chat_models.batch import BatchError
try:
batch_id = llm.batch_create(messages_list)
results = llm.batch_retrieve(batch_id, timeout=600.0)
except BatchError as e:
print(f"Batch processing failed: {e}")
# Handle batch failure (retry, fallback to standard API, etc.)
try:
batch_id = llm.batch_create(messages_list)
results = llm.batch_retrieve(batch_id, timeout=600.0)
except BatchError as e:
print(f"Batch processing failed: {e}")
# Handle batch failure (retry, fallback to standard API, etc.)
**Best practices:**
**Best practices:**
- Use batch API for **non-urgent tasks** where 50% cost savings justify longer wait times
- Set appropriate **timeouts** based on batch size (larger batches take longer)
- Include **descriptive metadata** for tracking and debugging batch jobs
- Consider **fallback strategies** for time-sensitive applications
- Monitor batch status for **long-running jobs** to detect failures early
- Use batch API for **non-urgent tasks** where 50% cost savings justify longer wait times
- Set appropriate **timeouts** based on batch size (larger batches take longer)
- Include **descriptive metadata** for tracking and debugging batch jobs
- Consider **fallback strategies** for time-sensitive applications
- Monitor batch status for **long-running jobs** to detect failures early
**When to use Batch API:**
**When to use Batch API:**
✅ **Good for:**
- Data processing and analysis
- Content generation at scale
- Model evaluation and testing
- Batch translation or summarization
- Non-interactive applications
✅ **Good for:**
- Data processing and analysis
- Content generation at scale
- Model evaluation and testing
- Batch translation or summarization
- Non-interactive applications
❌ **Not suitable for:**
- Real-time chat applications
- Interactive user interfaces
- Time-critical decision making
- Applications requiring immediate responses
❌ **Not suitable for:**
- Real-time chat applications
- Interactive user interfaces
- Time-critical decision making
- Applications requiring immediate responses
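**Fallback sketch:** Building on the best practice above of keeping a fallback for time-sensitive work, a minimal illustrative pattern (the helper name ``run_with_fallback`` is hypothetical) is to try the Batch API first and drop back to the standard API on failure:
.. code-block:: python
from langchain_openai.chat_models.batch import BatchError
def run_with_fallback(llm, inputs, timeout=1800.0):
# Try the Batch API first; fall back to the standard API on failure.
try:
return llm.batch(inputs, use_batch_api=True, timeout=timeout)
except BatchError as exc:
print(f"Batch API failed ({exc}); falling back to the standard API")
return llm.batch(inputs, use_batch_api=False)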
""" # noqa: E501
# TODO: Count bound tools as part of input.
if tools is not None:
@@ -2147,6 +2147,7 @@ class BaseChatOpenAI(BaseChatModel):
else:
filtered[k] = v
return filtered
def batch_create(
self,
messages_list: List[List[BaseMessage]],
@@ -2159,10 +2160,10 @@ class BaseChatOpenAI(BaseChatModel):
) -> str:
"""
Create a batch job using OpenAI's Batch API for asynchronous processing.
This method provides 50% cost savings compared to the standard API in exchange
for asynchronous processing with polling for results.
Args:
messages_list: List of message sequences to process in batch.
description: Optional description for the batch job.
@@ -2170,34 +2171,34 @@ class BaseChatOpenAI(BaseChatModel):
poll_interval: Default time in seconds between status checks when polling.
timeout: Default maximum time in seconds to wait for completion.
**kwargs: Additional parameters to pass to chat completions.
Returns:
The batch ID for tracking the asynchronous job.
Raises:
BatchError: If batch creation fails.
Example:
.. code-block:: python
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage
llm = ChatOpenAI()
messages_list = [
[HumanMessage(content="What is 2+2?")],
[HumanMessage(content="What is the capital of France?")],
]
# Create batch job (50% cost savings)
batch_id = llm.batch_create(messages_list)
# Later, retrieve results
results = llm.batch_retrieve(batch_id)
"""
# Import here to avoid circular imports
from langchain_openai.chat_models.batch import OpenAIBatchProcessor
# Create batch processor with current model settings
processor = OpenAIBatchProcessor(
client=self.root_client,
@@ -2205,19 +2206,19 @@ class BaseChatOpenAI(BaseChatModel):
poll_interval=poll_interval,
timeout=timeout,
)
# Filter and prepare kwargs for batch processing
batch_kwargs = self._get_invocation_params(**kwargs)
# Remove model from kwargs since it's handled by the processor
batch_kwargs.pop("model", None)
return processor.create_batch(
messages_list=messages_list,
description=description,
metadata=metadata,
**batch_kwargs,
)
def batch_retrieve(
self,
batch_id: str,
@@ -2227,36 +2228,36 @@ class BaseChatOpenAI(BaseChatModel):
) -> List[ChatResult]:
"""
Retrieve results from a batch job, polling until completion if necessary.
This method will poll the batch status until completion and return the results
converted to LangChain ChatResult format.
Args:
batch_id: The batch ID returned from batch_create().
poll_interval: Time in seconds between status checks. Uses default if None.
timeout: Maximum time in seconds to wait. Uses default if None.
Returns:
List of ChatResult objects corresponding to the original message sequences.
Raises:
BatchError: If batch retrieval fails, times out, or batch job failed.
Example:
.. code-block:: python
# After creating a batch job
batch_id = llm.batch_create(messages_list)
# Retrieve results (will poll until completion)
results = llm.batch_retrieve(batch_id)
for result in results:
print(result.generations[0].message.content)
"""
# Import here to avoid circular imports
from langchain_openai.chat_models.batch import OpenAIBatchProcessor
# Create batch processor with current model settings
processor = OpenAIBatchProcessor(
client=self.root_client,
@@ -2264,15 +2265,14 @@ class BaseChatOpenAI(BaseChatModel):
poll_interval=poll_interval or 10.0,
timeout=timeout,
)
# Poll for completion and retrieve results
processor.poll_batch_status(
batch_id=batch_id,
poll_interval=poll_interval,
timeout=timeout,
batch_id=batch_id, poll_interval=poll_interval, timeout=timeout
)
return processor.retrieve_batch_results(batch_id)
@override
def batch(
self,
@@ -2285,11 +2285,11 @@ class BaseChatOpenAI(BaseChatModel):
) -> List[BaseMessage]:
"""
Batch process multiple inputs using either standard API or OpenAI Batch API.
This method provides two processing modes:
1. Standard mode (use_batch_api=False): Uses parallel invoke for immediate results
2. Batch API mode (use_batch_api=True): Uses OpenAI's Batch API for 50% cost savings
Args:
inputs: List of inputs to process in batch.
config: Configuration for the batch processing.
@@ -2297,33 +2297,33 @@ class BaseChatOpenAI(BaseChatModel):
use_batch_api: If True, use OpenAI's Batch API for cost savings with polling.
If False (default), use standard parallel processing for immediate results.
**kwargs: Additional parameters to pass to the underlying model.
Returns:
List of BaseMessage objects corresponding to the inputs.
Raises:
BatchError: If batch processing fails (when use_batch_api=True).
Note:
**Cost vs Latency Tradeoff:**
- use_batch_api=False: Immediate results, standard API pricing
- use_batch_api=True: 50% cost savings, asynchronous processing with polling
Example:
.. code-block:: python
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage
llm = ChatOpenAI()
inputs = [
[HumanMessage(content="What is 2+2?")],
[HumanMessage(content="What is the capital of France?")],
]
# Standard processing (immediate results)
results = llm.batch(inputs)
# Batch API processing (50% cost savings, polling required)
results = llm.batch(inputs, use_batch_api=True)
"""
@@ -2338,11 +2338,11 @@ class BaseChatOpenAI(BaseChatModel):
# Convert single input to list of messages
messages = self._convert_input_to_messages(input_item)
messages_list.append(messages)
# Create batch job and poll for results
batch_id = self.batch_create(messages_list, **kwargs)
chat_results = self.batch_retrieve(batch_id)
# Convert ChatResult objects to BaseMessage objects
return [result.generations[0].message for result in chat_results]
else:
@@ -2354,7 +2354,9 @@ class BaseChatOpenAI(BaseChatModel):
**kwargs,
)
def _convert_input_to_messages(self, input_item: LanguageModelInput) -> List[BaseMessage]:
def _convert_input_to_messages(
self, input_item: LanguageModelInput
) -> List[BaseMessage]:
"""Convert various input formats to a list of BaseMessage objects."""
if isinstance(input_item, list):
# Already a list of messages
@@ -2365,27 +2367,17 @@ class BaseChatOpenAI(BaseChatModel):
elif isinstance(input_item, str):
# String input - convert to HumanMessage
from langchain_core.messages import HumanMessage
return [HumanMessage(content=input_item)]
elif hasattr(input_item, 'to_messages'):
elif hasattr(input_item, "to_messages"):
# PromptValue or similar
return input_item.to_messages()
else:
# Try to convert to string and then to HumanMessage
from langchain_core.messages import HumanMessage
return [HumanMessage(content=str(input_item))]
def _get_generation_chunk_from_completion(
self, completion: openai.BaseModel
) -> ChatGenerationChunk:

View File

@@ -5,19 +5,22 @@ from __future__ import annotations
import json
import time
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional
from uuid import uuid4
import openai
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_openai.chat_models._compat import convert_dict_to_message, convert_message_to_dict
from langchain_openai.chat_models._compat import (
convert_dict_to_message,
convert_message_to_dict,
)
class BatchStatus(str, Enum):
"""OpenAI Batch API status values."""
VALIDATING = "validating"
FAILED = "failed"
IN_PROGRESS = "in_progress"
@@ -30,8 +33,10 @@ class BatchStatus(str, Enum):
class BatchError(Exception):
"""Exception raised when batch processing fails."""
def __init__(self, message: str, batch_id: Optional[str] = None, status: Optional[str] = None):
def __init__(
self, message: str, batch_id: Optional[str] = None, status: Optional[str] = None
):
super().__init__(message)
self.batch_id = batch_id
self.status = status
@@ -39,22 +44,22 @@ class BatchError(Exception):
class OpenAIBatchClient:
"""
OpenAI Batch API client wrapper that handles batch creation, status polling,
OpenAI Batch API client wrapper that handles batch creation, status polling,
and result retrieval.
This class provides a high-level interface to OpenAI's Batch API, which offers
50% cost savings compared to the standard API in exchange for asynchronous processing.
"""
def __init__(self, client: openai.OpenAI):
"""
Initialize the batch client.
Args:
client: OpenAI client instance to use for API calls.
"""
self.client = client
def create_batch(
self,
requests: List[Dict[str, Any]],
@@ -63,53 +68,52 @@ class OpenAIBatchClient:
) -> str:
"""
Create a new batch job with the OpenAI Batch API.
Args:
requests: List of request objects in OpenAI batch format.
description: Optional description for the batch job.
metadata: Optional metadata to attach to the batch job.
Returns:
The batch ID for tracking the job.
Raises:
BatchError: If batch creation fails.
"""
try:
# Create JSONL content for the batch
jsonl_content = "\n".join(json.dumps(req) for req in requests)
# Upload the batch file
file_response = self.client.files.create(
file=jsonl_content.encode('utf-8'),
purpose="batch"
file=jsonl_content.encode("utf-8"), purpose="batch"
)
# Create the batch job
batch_response = self.client.batches.create(
input_file_id=file_response.id,
endpoint="/v1/chat/completions",
completion_window="24h",
metadata=metadata or {}
metadata=metadata or {},
)
return batch_response.id
except openai.OpenAIError as e:
raise BatchError(f"Failed to create batch: {e}") from e
except Exception as e:
raise BatchError(f"Unexpected error creating batch: {e}") from e
def retrieve_batch(self, batch_id: str) -> Dict[str, Any]:
"""
Retrieve batch information by ID.
Args:
batch_id: The batch ID to retrieve.
Returns:
Dictionary containing batch information including status.
Raises:
BatchError: If batch retrieval fails.
"""
@@ -119,20 +123,24 @@ class OpenAIBatchClient:
"id": batch.id,
"status": batch.status,
"created_at": batch.created_at,
"completed_at": getattr(batch, 'completed_at', None),
"failed_at": getattr(batch, 'failed_at', None),
"expired_at": getattr(batch, 'expired_at', None),
"request_counts": getattr(batch, 'request_counts', {}),
"metadata": getattr(batch, 'metadata', {}),
"errors": getattr(batch, 'errors', None),
"output_file_id": getattr(batch, 'output_file_id', None),
"error_file_id": getattr(batch, 'error_file_id', None),
"completed_at": getattr(batch, "completed_at", None),
"failed_at": getattr(batch, "failed_at", None),
"expired_at": getattr(batch, "expired_at", None),
"request_counts": getattr(batch, "request_counts", {}),
"metadata": getattr(batch, "metadata", {}),
"errors": getattr(batch, "errors", None),
"output_file_id": getattr(batch, "output_file_id", None),
"error_file_id": getattr(batch, "error_file_id", None),
}
except openai.OpenAIError as e:
raise BatchError(f"Failed to retrieve batch {batch_id}: {e}", batch_id=batch_id) from e
raise BatchError(
f"Failed to retrieve batch {batch_id}: {e}", batch_id=batch_id
) from e
except Exception as e:
raise BatchError(f"Unexpected error retrieving batch {batch_id}: {e}", batch_id=batch_id) from e
raise BatchError(
f"Unexpected error retrieving batch {batch_id}: {e}", batch_id=batch_id
) from e
def poll_batch_status(
self,
batch_id: str,
@@ -141,94 +149,106 @@ class OpenAIBatchClient:
) -> Dict[str, Any]:
"""
Poll batch status until completion or failure.
Args:
batch_id: The batch ID to poll.
poll_interval: Time in seconds between status checks.
timeout: Maximum time in seconds to wait. None for no timeout.
Returns:
Final batch information when completed.
Raises:
BatchError: If batch fails or times out.
"""
start_time = time.time()
while True:
batch_info = self.retrieve_batch(batch_id)
status = batch_info["status"]
if status == BatchStatus.COMPLETED:
return batch_info
elif status in [BatchStatus.FAILED, BatchStatus.EXPIRED, BatchStatus.CANCELLED]:
elif status in [
BatchStatus.FAILED,
BatchStatus.EXPIRED,
BatchStatus.CANCELLED,
]:
error_msg = f"Batch {batch_id} failed with status: {status}"
if batch_info.get("errors"):
error_msg += f". Errors: {batch_info['errors']}"
raise BatchError(error_msg, batch_id=batch_id, status=status)
# Check timeout
if timeout and (time.time() - start_time) > timeout:
raise BatchError(
f"Batch {batch_id} timed out after {timeout} seconds. Current status: {status}",
batch_id=batch_id,
status=status
status=status,
)
time.sleep(poll_interval)
def retrieve_batch_results(self, batch_id: str) -> List[Dict[str, Any]]:
"""
Retrieve results from a completed batch.
Args:
batch_id: The batch ID to retrieve results for.
Returns:
List of result objects from the batch.
Raises:
BatchError: If batch is not completed or result retrieval fails.
"""
try:
batch_info = self.retrieve_batch(batch_id)
if batch_info["status"] != BatchStatus.COMPLETED:
raise BatchError(
f"Batch {batch_id} is not completed. Current status: {batch_info['status']}",
batch_id=batch_id,
status=batch_info["status"]
status=batch_info["status"],
)
output_file_id = batch_info.get("output_file_id")
if not output_file_id:
raise BatchError(f"No output file found for batch {batch_id}", batch_id=batch_id)
raise BatchError(
f"No output file found for batch {batch_id}", batch_id=batch_id
)
# Download and parse the results file
file_content = self.client.files.content(output_file_id)
results = []
for line in file_content.text.strip().split('\n'):
for line in file_content.text.strip().split("\n"):
if line.strip():
results.append(json.loads(line))
return results
except openai.OpenAIError as e:
raise BatchError(f"Failed to retrieve results for batch {batch_id}: {e}", batch_id=batch_id) from e
raise BatchError(
f"Failed to retrieve results for batch {batch_id}: {e}",
batch_id=batch_id,
) from e
except Exception as e:
raise BatchError(f"Unexpected error retrieving results for batch {batch_id}: {e}", batch_id=batch_id) from e
raise BatchError(
f"Unexpected error retrieving results for batch {batch_id}: {e}",
batch_id=batch_id,
) from e
def cancel_batch(self, batch_id: str) -> Dict[str, Any]:
"""
Cancel a batch job.
Args:
batch_id: The batch ID to cancel.
Returns:
Updated batch information after cancellation.
Raises:
BatchError: If batch cancellation fails.
"""
@@ -236,22 +256,26 @@ class OpenAIBatchClient:
batch = self.client.batches.cancel(batch_id)
return self.retrieve_batch(batch_id)
except openai.OpenAIError as e:
raise BatchError(f"Failed to cancel batch {batch_id}: {e}", batch_id=batch_id) from e
raise BatchError(
f"Failed to cancel batch {batch_id}: {e}", batch_id=batch_id
) from e
except Exception as e:
raise BatchError(f"Unexpected error cancelling batch {batch_id}: {e}", batch_id=batch_id) from e
raise BatchError(
f"Unexpected error cancelling batch {batch_id}: {e}", batch_id=batch_id
) from e
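A minimal illustrative sketch (not part of the module) of driving ``OpenAIBatchClient`` directly from user code, using the request format produced by ``create_batch_request`` defined later in this file and assuming ``OPENAI_API_KEY`` is set in the environment:
import openai
from langchain_core.messages import HumanMessage
from langchain_openai.chat_models.batch import OpenAIBatchClient, create_batch_request
client = OpenAIBatchClient(openai.OpenAI())
requests = [
create_batch_request(
messages=[HumanMessage(content="What is 2+2?")],
model="gpt-3.5-turbo",
custom_id="request_0_example",
)
]
batch_id = client.create_batch(requests, description="low-level example")
client.poll_batch_status(batch_id, poll_interval=30.0, timeout=1800.0)
raw_results = client.retrieve_batch_results(batch_id)  # parsed JSONL result dicts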
class OpenAIBatchProcessor:
"""
High-level processor for managing OpenAI Batch API lifecycle with LangChain integration.
This class handles the complete batch processing workflow:
1. Converts LangChain messages to OpenAI batch format
2. Creates batch jobs using the OpenAI Batch API
3. Polls for completion with configurable intervals
4. Converts results back to LangChain format
"""
def __init__(
self,
client: openai.OpenAI,
@@ -261,7 +285,7 @@ class OpenAIBatchProcessor:
):
"""
Initialize the batch processor.
Args:
client: OpenAI client instance to use for API calls.
model: The model to use for batch requests.
@@ -272,7 +296,7 @@ class OpenAIBatchProcessor:
self.model = model
self.poll_interval = poll_interval
self.timeout = timeout
def create_batch(
self,
messages_list: List[List[BaseMessage]],
@@ -282,16 +306,16 @@ class OpenAIBatchProcessor:
) -> str:
"""
Create a batch job from a list of LangChain message sequences.
Args:
messages_list: List of message sequences to process in batch.
description: Optional description for the batch job.
metadata: Optional metadata to attach to the batch job.
**kwargs: Additional parameters to pass to chat completions.
Returns:
The batch ID for tracking the job.
Raises:
BatchError: If batch creation fails.
"""
@@ -300,19 +324,14 @@ class OpenAIBatchProcessor:
for i, messages in enumerate(messages_list):
custom_id = f"request_{i}_{uuid4().hex[:8]}"
request = create_batch_request(
messages=messages,
model=self.model,
custom_id=custom_id,
**kwargs
messages=messages, model=self.model, custom_id=custom_id, **kwargs
)
requests.append(request)
return self.batch_client.create_batch(
requests=requests,
description=description,
metadata=metadata,
requests=requests, description=description, metadata=metadata
)
def poll_batch_status(
self,
batch_id: str,
@@ -321,15 +340,15 @@ class OpenAIBatchProcessor:
) -> Dict[str, Any]:
"""
Poll batch status until completion or failure.
Args:
batch_id: The batch ID to poll.
poll_interval: Time in seconds between status checks. Uses default if None.
timeout: Maximum time in seconds to wait. Uses default if None.
Returns:
Final batch information when completed.
Raises:
BatchError: If batch fails or times out.
"""
@@ -338,26 +357,26 @@ class OpenAIBatchProcessor:
poll_interval=poll_interval or self.poll_interval,
timeout=timeout or self.timeout,
)
def retrieve_batch_results(self, batch_id: str) -> List[ChatResult]:
"""
Retrieve and convert batch results to LangChain format.
Args:
batch_id: The batch ID to retrieve results for.
Returns:
List of ChatResult objects corresponding to the original message sequences.
Raises:
BatchError: If batch is not completed or result retrieval fails.
"""
# Get raw results from OpenAI
raw_results = self.batch_client.retrieve_batch_results(batch_id)
# Sort results by custom_id to maintain order
raw_results.sort(key=lambda x: x.get("custom_id", ""))
# Convert to LangChain ChatResult format
chat_results = []
for result in raw_results:
@@ -365,39 +384,42 @@ class OpenAIBatchProcessor:
# Handle individual request errors
error_msg = f"Request failed: {result['error']}"
raise BatchError(error_msg, batch_id=batch_id)
response = result.get("response", {})
if not response:
raise BatchError(f"No response found in result: {result}", batch_id=batch_id)
raise BatchError(
f"No response found in result: {result}", batch_id=batch_id
)
body = response.get("body", {})
choices = body.get("choices", [])
if not choices:
raise BatchError(f"No choices found in response: {body}", batch_id=batch_id)
raise BatchError(
f"No choices found in response: {body}", batch_id=batch_id
)
# Convert OpenAI response to LangChain format
generations = []
for choice in choices:
message_dict = choice.get("message", {})
if not message_dict:
continue
# Convert OpenAI message dict to LangChain message
message = convert_dict_to_message(message_dict)
# Create ChatGeneration with metadata
generation_info = {
"finish_reason": choice.get("finish_reason"),
"logprobs": choice.get("logprobs"),
}
generation = ChatGeneration(
message=message,
generation_info=generation_info,
message=message, generation_info=generation_info
)
generations.append(generation)
# Create ChatResult with usage information
usage = body.get("usage", {})
llm_output = {
@@ -405,15 +427,12 @@ class OpenAIBatchProcessor:
"model_name": body.get("model"),
"system_fingerprint": body.get("system_fingerprint"),
}
chat_result = ChatResult(
generations=generations,
llm_output=llm_output,
)
chat_result = ChatResult(generations=generations, llm_output=llm_output)
chat_results.append(chat_result)
return chat_results
def process_batch(
self,
messages_list: List[List[BaseMessage]],
@@ -425,7 +444,7 @@ class OpenAIBatchProcessor:
) -> List[ChatResult]:
"""
Complete batch processing workflow: create, poll, and retrieve results.
Args:
messages_list: List of message sequences to process in batch.
description: Optional description for the batch job.
@@ -433,10 +452,10 @@ class OpenAIBatchProcessor:
poll_interval: Time in seconds between status checks. Uses default if None.
timeout: Maximum time in seconds to wait. Uses default if None.
**kwargs: Additional parameters to pass to chat completions.
Returns:
List of ChatResult objects corresponding to the original message sequences.
Raises:
BatchError: If any step of the batch processing fails.
"""
@@ -445,50 +464,39 @@ class OpenAIBatchProcessor:
messages_list=messages_list,
description=description,
metadata=metadata,
**kwargs
**kwargs,
)
# Poll until completion
self.poll_batch_status(
batch_id=batch_id,
poll_interval=poll_interval,
timeout=timeout,
batch_id=batch_id, poll_interval=poll_interval, timeout=timeout
)
# Retrieve and return results
return self.retrieve_batch_results(batch_id)
def create_batch_request(
messages: List[BaseMessage],
model: str,
custom_id: str,
**kwargs: Any,
messages: List[BaseMessage], model: str, custom_id: str, **kwargs: Any
) -> Dict[str, Any]:
"""
Create a batch request object from LangChain messages.
Args:
messages: List of LangChain messages to convert.
model: The model to use for the request.
custom_id: Unique identifier for this request within the batch.
**kwargs: Additional parameters to pass to the chat completion.
Returns:
Dictionary in OpenAI batch request format.
"""
# Convert LangChain messages to OpenAI format
openai_messages = [convert_message_to_dict(msg) for msg in messages]
return {
"custom_id": custom_id,
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": model,
"messages": openai_messages,
**kwargs
}
"body": {"model": model, "messages": openai_messages, **kwargs},
}
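For context, each dictionary returned by ``create_batch_request`` becomes one JSONL line in the uploaded batch input file. An illustrative sketch of the resulting shape (field values are placeholders):
from langchain_core.messages import HumanMessage
request = create_batch_request(
messages=[HumanMessage(content="Hello")],
model="gpt-3.5-turbo",
custom_id="request_0_deadbeef",
temperature=0.0,
)
# request == {
#     "custom_id": "request_0_deadbeef",
#     "method": "POST",
#     "url": "/v1/chat/completions",
#     "body": {
#         "model": "gpt-3.5-turbo",
#         "messages": [{"role": "user", "content": "Hello"}],
#         "temperature": 0.0,
#     },
# }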

View File

@@ -6,20 +6,18 @@ They are designed to test the complete end-to-end batch processing workflow.
import os
import time
from typing import List
import pytest
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.outputs import ChatResult
from langchain_openai import ChatOpenAI
from langchain_openai.chat_models.batch import BatchError
# Skip all tests if no API key is available
pytestmark = pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY"),
reason="OPENAI_API_KEY not set, skipping integration tests"
reason="OPENAI_API_KEY not set, skipping integration tests",
)
@@ -31,7 +29,7 @@ class TestBatchAPIIntegration:
self.llm = ChatOpenAI(
model="gpt-3.5-turbo",
temperature=0.1, # Low temperature for consistent results
max_tokens=50, # Keep responses short for faster processing
max_tokens=50, # Keep responses short for faster processing
)
@pytest.mark.scheduled
@@ -40,7 +38,11 @@ class TestBatchAPIIntegration:
# Create a small batch of simple questions
messages_list = [
[HumanMessage(content="What is 2+2? Answer with just the number.")],
[HumanMessage(content="What is the capital of France? Answer with just the city name.")],
[
HumanMessage(
content="What is the capital of France? Answer with just the city name."
)
],
]
# Create batch job
@@ -58,19 +60,21 @@ class TestBatchAPIIntegration:
results = self.llm.batch_retrieve(
batch_id=batch_id,
poll_interval=30.0, # Poll every 30 seconds
timeout=1800.0, # 30 minute timeout
timeout=1800.0, # 30 minute timeout
)
# Verify results
assert len(results) == 2
assert all(isinstance(result, ChatResult) for result in results)
assert all(len(result.generations) == 1 for result in results)
assert all(isinstance(result.generations[0].message, AIMessage) for result in results)
assert all(
isinstance(result.generations[0].message, AIMessage) for result in results
)
# Check that we got reasonable responses
response1 = results[0].generations[0].message.content.strip()
response2 = results[1].generations[0].message.content.strip()
# Basic sanity checks (responses should contain expected content)
assert "4" in response1 or "four" in response1.lower()
assert "paris" in response2.lower()
@@ -85,10 +89,7 @@ class TestBatchAPIIntegration:
# Use batch API mode
results = self.llm.batch(
inputs,
use_batch_api=True,
poll_interval=30.0,
timeout=1800.0,
inputs, use_batch_api=True, poll_interval=30.0, timeout=1800.0
)
# Verify results
@@ -99,37 +100,32 @@ class TestBatchAPIIntegration:
# Basic sanity checks
response1 = results[0].content.strip().lower()
response2 = results[1].content.strip().lower()
assert any(char in response1 for char in ["1", "2", "3"])
assert "blue" in response2
@pytest.mark.scheduled
def test_batch_method_comparison(self):
"""Test that batch API and standard batch produce similar results."""
inputs = [
[HumanMessage(content="What is 1+1? Answer with just the number.")],
]
inputs = [[HumanMessage(content="What is 1+1? Answer with just the number.")]]
# Test standard batch processing
standard_results = self.llm.batch(inputs, use_batch_api=False)
# Test batch API processing
batch_api_results = self.llm.batch(
inputs,
use_batch_api=True,
poll_interval=30.0,
timeout=1800.0,
inputs, use_batch_api=True, poll_interval=30.0, timeout=1800.0
)
# Both should return similar structure
assert len(standard_results) == len(batch_api_results) == 1
assert isinstance(standard_results[0], AIMessage)
assert isinstance(batch_api_results[0], AIMessage)
# Both should contain reasonable answers
standard_content = standard_results[0].content.strip()
batch_content = batch_api_results[0].content.strip()
assert "2" in standard_content or "two" in standard_content.lower()
assert "2" in batch_content or "two" in batch_content.lower()
@@ -137,7 +133,7 @@ class TestBatchAPIIntegration:
def test_batch_with_different_parameters(self):
"""Test batch processing with different model parameters."""
messages_list = [
[HumanMessage(content="Write a haiku about coding. Keep it short.")],
[HumanMessage(content="Write a haiku about coding. Keep it short.")]
]
# Create batch with specific parameters
@@ -146,18 +142,16 @@ class TestBatchAPIIntegration:
description="Integration test - parameters",
metadata={"test_type": "parameters"},
temperature=0.8, # Higher temperature for creativity
max_tokens=100, # More tokens for haiku
max_tokens=100, # More tokens for haiku
)
results = self.llm.batch_retrieve(
batch_id=batch_id,
poll_interval=30.0,
timeout=1800.0,
batch_id=batch_id, poll_interval=30.0, timeout=1800.0
)
assert len(results) == 1
result_content = results[0].generations[0].message.content
# Should have some content (haiku)
assert len(result_content.strip()) > 10
# Haikus typically have line breaks
@@ -170,25 +164,24 @@ class TestBatchAPIIntegration:
messages_list = [
[
SystemMessage(content="You are a helpful math tutor. Answer concisely."),
SystemMessage(
content="You are a helpful math tutor. Answer concisely."
),
HumanMessage(content="What is 5 * 6?"),
],
]
]
batch_id = self.llm.batch_create(
messages_list=messages_list,
description="Integration test - system message",
messages_list=messages_list, description="Integration test - system message"
)
results = self.llm.batch_retrieve(
batch_id=batch_id,
poll_interval=30.0,
timeout=1800.0,
batch_id=batch_id, poll_interval=30.0, timeout=1800.0
)
assert len(results) == 1
result_content = results[0].generations[0].message.content.strip()
# Should contain the answer
assert "30" in result_content or "thirty" in result_content.lower()
@@ -196,14 +189,9 @@ class TestBatchAPIIntegration:
def test_batch_error_handling_invalid_model(self):
"""Test error handling with invalid model parameters."""
# Create a ChatOpenAI instance with an invalid model
invalid_llm = ChatOpenAI(
model="invalid-model-name-12345",
temperature=0.1,
)
invalid_llm = ChatOpenAI(model="invalid-model-name-12345", temperature=0.1)
messages_list = [
[HumanMessage(content="Hello")],
]
messages_list = [[HumanMessage(content="Hello")]]
# This should fail during batch creation or processing
with pytest.raises(BatchError):
@@ -216,23 +204,24 @@ class TestBatchAPIIntegration:
# Test with string inputs (should be converted to HumanMessage)
inputs = [
"What is the largest planet? Answer with just the planet name.",
[HumanMessage(content="What is the smallest planet? Answer with just the planet name.")],
[
HumanMessage(
content="What is the smallest planet? Answer with just the planet name."
)
],
]
results = self.llm.batch(
inputs,
use_batch_api=True,
poll_interval=30.0,
timeout=1800.0,
inputs, use_batch_api=True, poll_interval=30.0, timeout=1800.0
)
assert len(results) == 2
assert all(isinstance(result, AIMessage) for result in results)
# Check for reasonable responses
response1 = results[0].content.strip().lower()
response2 = results[1].content.strip().lower()
assert "jupiter" in response1
assert "mercury" in response2
@@ -248,12 +237,10 @@ class TestBatchAPIIntegration:
results = self.llm.batch_retrieve(batch_id, timeout=300.0)
assert results == []
@pytest.mark.scheduled
@pytest.mark.scheduled
def test_batch_metadata_preservation(self):
"""Test that batch metadata is properly handled."""
messages_list = [
[HumanMessage(content="Say 'test successful'")],
]
messages_list = [[HumanMessage(content="Say 'test successful'")]]
metadata = {
"test_name": "metadata_test",
@@ -270,7 +257,7 @@ class TestBatchAPIIntegration:
# Retrieve results
results = self.llm.batch_retrieve(batch_id, timeout=1800.0)
assert len(results) == 1
result_content = results[0].generations[0].message.content.strip().lower()
assert "test successful" in result_content
@@ -281,18 +268,12 @@ class TestBatchAPIEdgeCases:
def setup_method(self):
"""Set up test fixtures."""
self.llm = ChatOpenAI(
model="gpt-3.5-turbo",
temperature=0.1,
max_tokens=50,
)
self.llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.1, max_tokens=50)
@pytest.mark.scheduled
def test_batch_with_very_short_timeout(self):
"""Test batch processing with very short timeout."""
messages_list = [
[HumanMessage(content="Hello")],
]
messages_list = [[HumanMessage(content="Hello")]]
batch_id = self.llm.batch_create(messages_list=messages_list)
@@ -313,10 +294,8 @@ class TestBatchAPIEdgeCases:
def test_batch_with_long_content(self):
"""Test batch processing with longer content."""
long_content = "Please summarize this text: " + "This is a test sentence. " * 20
messages_list = [
[HumanMessage(content=long_content)],
]
messages_list = [[HumanMessage(content=long_content)]]
batch_id = self.llm.batch_create(
messages_list=messages_list,
@@ -324,10 +303,10 @@ class TestBatchAPIEdgeCases:
)
results = self.llm.batch_retrieve(batch_id, timeout=1800.0)
assert len(results) == 1
result_content = results[0].generations[0].message.content
# Should have some summary content
assert len(result_content.strip()) > 10
@@ -353,7 +332,7 @@ class TestBatchAPIPerformance:
]
start_time = time.time()
batch_id = self.llm.batch_create(
messages_list=messages_list,
description="Medium batch test - 10 requests",
@@ -363,7 +342,7 @@ class TestBatchAPIPerformance:
results = self.llm.batch_retrieve(
batch_id=batch_id,
poll_interval=60.0, # Poll every minute
timeout=3600.0, # 1 hour timeout
timeout=3600.0, # 1 hour timeout
)
end_time = time.time()
@@ -401,20 +380,17 @@ class TestBatchAPIPerformance:
# Test batch API processing time
start_batch = time.time()
batch_results = self.llm.batch(
messages,
use_batch_api=True,
poll_interval=30.0,
timeout=1800.0,
messages, use_batch_api=True, poll_interval=30.0, timeout=1800.0
)
batch_time = time.time() - start_batch
# Verify both produce results
assert len(sequential_results) == len(batch_results) == 2
# Log timing comparison
print(f"Sequential processing: {sequential_time:.2f}s")
print(f"Batch API processing: {batch_time:.2f}s")
# Note: Batch API will typically be slower for small batches due to polling,
# but should be more cost-effective for larger batches
@@ -424,6 +400,7 @@ def is_openai_api_available() -> bool:
"""Check if OpenAI API is available and accessible."""
try:
import openai
client = openai.OpenAI()
# Try a simple API call to verify connectivity
client.models.list()

View File

@@ -1,12 +1,10 @@
"""Test OpenAI Batch API functionality."""
import json
import time
from typing import Any, Dict, List
from unittest.mock import MagicMock, Mock, patch
from unittest.mock import MagicMock, patch
import pytest
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_openai import ChatOpenAI
@@ -42,7 +40,10 @@ class TestOpenAIBatchClient:
"custom_id": "request-1",
"method": "POST",
"url": "/v1/chat/completions",
"body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello"}]},
"body": {
"model": "gpt-3.5-turbo",
"messages": [{"role": "user", "content": "Hello"}],
},
}
]
@@ -64,7 +65,10 @@ class TestOpenAIBatchClient:
"custom_id": "request-1",
"method": "POST",
"url": "/v1/chat/completions",
"body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello"}]},
"body": {
"model": "gpt-3.5-turbo",
"messages": [{"role": "user", "content": "Hello"}],
},
}
]
@@ -76,10 +80,10 @@ class TestOpenAIBatchClient:
# Mock batch status progression
mock_batch_validating = MagicMock()
mock_batch_validating.status = "validating"
mock_batch_in_progress = MagicMock()
mock_batch_in_progress.status = "in_progress"
mock_batch_completed = MagicMock()
mock_batch_completed.status = "completed"
mock_batch_completed.output_file_id = "file_123"
@@ -134,13 +138,17 @@ class TestOpenAIBatchClient:
{
"message": {
"role": "assistant",
"content": "Hello! How can I help you?"
"content": "Hello! How can I help you?",
}
}
],
"usage": {"prompt_tokens": 10, "completion_tokens": 8, "total_tokens": 18}
}
}
"usage": {
"prompt_tokens": 10,
"completion_tokens": 8,
"total_tokens": 18,
},
},
},
}
]
@@ -177,7 +185,7 @@ class TestOpenAIBatchProcessor:
def test_create_batch_success(self):
"""Test successful batch creation with message conversion."""
# Mock batch client
with patch.object(self.processor, 'batch_client') as mock_batch_client:
with patch.object(self.processor, "batch_client") as mock_batch_client:
mock_batch_client.create_batch.return_value = "batch_123"
messages_list = [
@@ -198,7 +206,7 @@ class TestOpenAIBatchProcessor:
# Verify batch requests were created correctly
call_args = mock_batch_client.create_batch.call_args
batch_requests = call_args[1]["requests"]
assert len(batch_requests) == 2
assert batch_requests[0]["custom_id"].startswith("request_0_")
assert batch_requests[0]["body"]["model"] == "gpt-3.5-turbo"
@@ -208,7 +216,7 @@ class TestOpenAIBatchProcessor:
def test_poll_batch_status_success(self):
"""Test successful batch status polling."""
with patch.object(self.processor, 'batch_client') as mock_batch_client:
with patch.object(self.processor, "batch_client") as mock_batch_client:
mock_batch = MagicMock()
mock_batch.status = "completed"
mock_batch_client.poll_batch_status.return_value = mock_batch
@@ -238,13 +246,17 @@ class TestOpenAIBatchProcessor:
{
"message": {
"role": "assistant",
"content": "2+2 equals 4."
"content": "2+2 equals 4.",
}
}
],
"usage": {"prompt_tokens": 10, "completion_tokens": 8, "total_tokens": 18}
}
}
"usage": {
"prompt_tokens": 10,
"completion_tokens": 8,
"total_tokens": 18,
},
},
},
},
{
"id": "batch_req_124",
@@ -256,35 +268,42 @@ class TestOpenAIBatchProcessor:
{
"message": {
"role": "assistant",
"content": "The capital of France is Paris."
"content": "The capital of France is Paris.",
}
}
],
"usage": {"prompt_tokens": 12, "completion_tokens": 10, "total_tokens": 22}
}
}
}
"usage": {
"prompt_tokens": 12,
"completion_tokens": 10,
"total_tokens": 22,
},
},
},
},
]
with patch.object(self.processor, 'batch_client') as mock_batch_client:
with patch.object(self.processor, "batch_client") as mock_batch_client:
mock_batch_client.poll_batch_status.return_value = mock_batch
mock_batch_client.retrieve_batch_results.return_value = mock_results
chat_results = self.processor.retrieve_batch_results("batch_123")
assert len(chat_results) == 2
# Check first result
assert isinstance(chat_results[0], ChatResult)
assert len(chat_results[0].generations) == 1
assert isinstance(chat_results[0].generations[0].message, AIMessage)
assert chat_results[0].generations[0].message.content == "2+2 equals 4."
# Check second result
assert isinstance(chat_results[1], ChatResult)
assert len(chat_results[1].generations) == 1
assert isinstance(chat_results[1].generations[0].message, AIMessage)
assert chat_results[1].generations[0].message.content == "The capital of France is Paris."
assert (
chat_results[1].generations[0].message.content
== "The capital of France is Paris."
)
def test_retrieve_batch_results_with_errors(self):
"""Test batch result retrieval with some failed requests."""
@@ -303,13 +322,17 @@ class TestOpenAIBatchProcessor:
{
"message": {
"role": "assistant",
"content": "Success response"
"content": "Success response",
}
}
],
"usage": {"prompt_tokens": 10, "completion_tokens": 8, "total_tokens": 18}
}
}
"usage": {
"prompt_tokens": 10,
"completion_tokens": 8,
"total_tokens": 18,
},
},
},
},
{
"id": "batch_req_124",
@@ -319,14 +342,14 @@ class TestOpenAIBatchProcessor:
"body": {
"error": {
"message": "Invalid request",
"type": "invalid_request_error"
"type": "invalid_request_error",
}
}
}
}
},
},
},
]
with patch.object(self.processor, 'batch_client') as mock_batch_client:
with patch.object(self.processor, "batch_client") as mock_batch_client:
mock_batch_client.poll_batch_status.return_value = mock_batch
mock_batch_client.retrieve_batch_results.return_value = mock_results
@@ -341,7 +364,7 @@ class TestBaseChatOpenAIBatchMethods:
"""Set up test fixtures."""
self.llm = ChatOpenAI(model="gpt-3.5-turbo", api_key="test-key")
@patch('langchain_openai.chat_models.batch.OpenAIBatchProcessor')
@patch("langchain_openai.chat_models.batch.OpenAIBatchProcessor")
def test_batch_create_success(self, mock_processor_class):
"""Test successful batch creation."""
mock_processor = MagicMock()
@@ -369,13 +392,21 @@ class TestBaseChatOpenAIBatchMethods:
temperature=0.7,
)
@patch('langchain_openai.chat_models.batch.OpenAIBatchProcessor')
@patch("langchain_openai.chat_models.batch.OpenAIBatchProcessor")
def test_batch_retrieve_success(self, mock_processor_class):
"""Test successful batch result retrieval."""
mock_processor = MagicMock()
mock_chat_results = [
ChatResult(generations=[ChatGeneration(message=AIMessage(content="2+2 equals 4."))]),
ChatResult(generations=[ChatGeneration(message=AIMessage(content="The capital of France is Paris."))]),
ChatResult(
generations=[ChatGeneration(message=AIMessage(content="2+2 equals 4."))]
),
ChatResult(
generations=[
ChatGeneration(
message=AIMessage(content="The capital of France is Paris.")
)
]
),
]
mock_processor.retrieve_batch_results.return_value = mock_chat_results
mock_processor_class.return_value = mock_processor
@@ -384,22 +415,27 @@ class TestBaseChatOpenAIBatchMethods:
assert len(results) == 2
assert results[0].generations[0].message.content == "2+2 equals 4."
assert results[1].generations[0].message.content == "The capital of France is Paris."
assert (
results[1].generations[0].message.content
== "The capital of France is Paris."
)
mock_processor.poll_batch_status.assert_called_once_with(
batch_id="batch_123",
poll_interval=1.0,
timeout=60.0,
batch_id="batch_123", poll_interval=1.0, timeout=60.0
)
mock_processor.retrieve_batch_results.assert_called_once_with("batch_123")
@patch('langchain_openai.chat_models.batch.OpenAIBatchProcessor')
@patch("langchain_openai.chat_models.batch.OpenAIBatchProcessor")
def test_batch_method_with_batch_api_true(self, mock_processor_class):
"""Test batch method with use_batch_api=True."""
mock_processor = MagicMock()
mock_chat_results = [
ChatResult(generations=[ChatGeneration(message=AIMessage(content="Response 1"))]),
ChatResult(generations=[ChatGeneration(message=AIMessage(content="Response 2"))]),
ChatResult(
generations=[ChatGeneration(message=AIMessage(content="Response 1"))]
),
ChatResult(
generations=[ChatGeneration(message=AIMessage(content="Response 2"))]
),
]
mock_processor.create_batch.return_value = "batch_123"
mock_processor.retrieve_batch_results.return_value = mock_chat_results
@@ -429,7 +465,7 @@ class TestBaseChatOpenAIBatchMethods:
]
# Mock the parent class batch method
with patch.object(ChatOpenAI.__bases__[0], 'batch') as mock_super_batch:
with patch.object(ChatOpenAI.__bases__[0], "batch") as mock_super_batch:
mock_super_batch.return_value = [
AIMessage(content="Response 1"),
AIMessage(content="Response 2"),
@@ -439,9 +475,7 @@ class TestBaseChatOpenAIBatchMethods:
assert len(results) == 2
mock_super_batch.assert_called_once_with(
inputs=inputs,
config=None,
return_exceptions=False,
inputs=inputs, config=None, return_exceptions=False
)
def test_convert_input_to_messages_list(self):
@@ -462,7 +496,7 @@ class TestBaseChatOpenAIBatchMethods:
assert isinstance(result[0], HumanMessage)
assert result[0].content == "Hello"
@patch('langchain_openai.chat_models.batch.OpenAIBatchProcessor')
@patch("langchain_openai.chat_models.batch.OpenAIBatchProcessor")
def test_batch_create_with_error_handling(self, mock_processor_class):
"""Test batch creation with error handling."""
mock_processor = MagicMock()
@@ -474,11 +508,13 @@ class TestBaseChatOpenAIBatchMethods:
with pytest.raises(BatchError, match="Batch creation failed"):
self.llm.batch_create(messages_list)
@patch('langchain_openai.chat_models.batch.OpenAIBatchProcessor')
@patch("langchain_openai.chat_models.batch.OpenAIBatchProcessor")
def test_batch_retrieve_with_error_handling(self, mock_processor_class):
"""Test batch retrieval with error handling."""
mock_processor = MagicMock()
mock_processor.poll_batch_status.side_effect = BatchError("Batch polling failed")
mock_processor.poll_batch_status.side_effect = BatchError(
"Batch polling failed"
)
mock_processor_class.return_value = mock_processor
with pytest.raises(BatchError, match="Batch polling failed"):
@@ -486,12 +522,15 @@ class TestBaseChatOpenAIBatchMethods:
def test_batch_method_input_conversion(self):
"""Test batch method handles various input formats correctly."""
with patch.object(self.llm, 'batch_create') as mock_create, \
patch.object(self.llm, 'batch_retrieve') as mock_retrieve:
with (
patch.object(self.llm, "batch_create") as mock_create,
patch.object(self.llm, "batch_retrieve") as mock_retrieve,
):
mock_create.return_value = "batch_123"
mock_retrieve.return_value = [
ChatResult(generations=[ChatGeneration(message=AIMessage(content="Response"))]),
ChatResult(
generations=[ChatGeneration(message=AIMessage(content="Response"))]
)
]
# Test with string inputs
@@ -501,8 +540,8 @@ class TestBaseChatOpenAIBatchMethods:
# Verify conversion happened
mock_create.assert_called_once()
call_args = mock_create.call_args[1]
messages_list = call_args['messages_list']
messages_list = call_args["messages_list"]
assert len(messages_list) == 1
assert len(messages_list[0]) == 1
assert isinstance(messages_list[0][0], HumanMessage)
@@ -532,7 +571,7 @@ class TestBatchIntegrationScenarios:
"""Set up test fixtures."""
self.llm = ChatOpenAI(model="gpt-3.5-turbo", api_key="test-key")
@patch('langchain_openai.chat_models.batch.OpenAIBatchProcessor')
@patch("langchain_openai.chat_models.batch.OpenAIBatchProcessor")
def test_empty_messages_list(self, mock_processor_class):
"""Test handling of empty messages list."""
mock_processor = MagicMock()
@@ -543,16 +582,18 @@ class TestBatchIntegrationScenarios:
results = self.llm.batch([], use_batch_api=True)
assert results == []
@patch('langchain_openai.chat_models.batch.OpenAIBatchProcessor')
@patch("langchain_openai.chat_models.batch.OpenAIBatchProcessor")
def test_large_batch_processing(self, mock_processor_class):
"""Test processing of large batch."""
mock_processor = MagicMock()
mock_processor.create_batch.return_value = "batch_123"
# Create mock results for large batch
num_requests = 100
mock_chat_results = [
ChatResult(generations=[ChatGeneration(message=AIMessage(content=f"Response {i}"))])
ChatResult(
generations=[ChatGeneration(message=AIMessage(content=f"Response {i}"))]
)
for i in range(num_requests)
]
mock_processor.retrieve_batch_results.return_value = mock_chat_results
@@ -565,14 +606,18 @@ class TestBatchIntegrationScenarios:
for i, result in enumerate(results):
assert result.content == f"Response {i}"
@patch('langchain_openai.chat_models.batch.OpenAIBatchProcessor')
@patch("langchain_openai.chat_models.batch.OpenAIBatchProcessor")
def test_mixed_message_types(self, mock_processor_class):
"""Test batch processing with mixed message types."""
mock_processor = MagicMock()
mock_processor.create_batch.return_value = "batch_123"
mock_processor.retrieve_batch_results.return_value = [
ChatResult(generations=[ChatGeneration(message=AIMessage(content="Response 1"))]),
ChatResult(generations=[ChatGeneration(message=AIMessage(content="Response 2"))]),
ChatResult(
generations=[ChatGeneration(message=AIMessage(content="Response 1"))]
),
ChatResult(
generations=[ChatGeneration(message=AIMessage(content="Response 2"))]
),
]
mock_processor_class.return_value = mock_processor
@@ -587,12 +632,12 @@ class TestBatchIntegrationScenarios:
# Verify the conversion happened correctly
mock_processor.create_batch.assert_called_once()
call_args = mock_processor.create_batch.call_args[1]
messages_list = call_args['messages_list']
messages_list = call_args["messages_list"]
# First input should be converted to HumanMessage
assert isinstance(messages_list[0][0], HumanMessage)
assert messages_list[0][0].content == "String input"
# Second input should remain as is
assert isinstance(messages_list[1][0], HumanMessage)
assert messages_list[1][0].content == "Direct message list"
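For reference, the raw per-request objects mocked throughout these tests mirror what ``OpenAIBatchProcessor.retrieve_batch_results`` parses from the batch output file, one JSON object per JSONL line. An illustrative sketch of the shape (field values are placeholders):
raw_result = {
"id": "batch_req_123",
"custom_id": "request_0_deadbeef",
"response": {
"body": {
"model": "gpt-3.5-turbo",
"choices": [
{
"message": {"role": "assistant", "content": "2+2 equals 4."},
"finish_reason": "stop",
}
],
"usage": {"prompt_tokens": 10, "completion_tokens": 8, "total_tokens": 18},
}
},
}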