Apply patch [skip ci]
@@ -1,4 +1,9 @@
from langchain_openai.chat_models import (
    AzureChatOpenAI,
    BatchError,
    BatchStatus,
    ChatOpenAI,
)
from langchain_openai.embeddings import AzureOpenAIEmbeddings, OpenAIEmbeddings
from langchain_openai.llms import AzureOpenAI, OpenAI
from langchain_openai.tools import custom_tool
@@ -14,4 +19,3 @@ __all__ = [
    "BatchError",
    "BatchStatus",
]

@@ -3,4 +3,3 @@ from langchain_openai.chat_models.base import ChatOpenAI
from langchain_openai.chat_models.batch import BatchError, BatchStatus

__all__ = ["ChatOpenAI", "AzureChatOpenAI", "BatchError", "BatchStatus"]
@@ -1503,142 +1503,142 @@ class BaseChatOpenAI(BaseChatModel):
    ) -> int:
        """Calculate num tokens for ``gpt-3.5-turbo`` and ``gpt-4`` with ``tiktoken`` package.

        **Requirements**: You must have ``pillow`` installed if you want to count
        image tokens for images passed as base64 strings, and you must have both
        ``pillow`` and ``httpx`` installed for images passed as URLs. If these
        aren't installed, image inputs will be ignored in token counting.

        `OpenAI reference <https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb>`__

        Args:
            messages: The message inputs to tokenize.
            tools: If provided, sequence of dict, BaseModel, function, or BaseTools
                to be converted to tool schemas.

        .. dropdown:: Batch API for cost savings

            .. versionadded:: 0.3.7

            OpenAI's Batch API provides **50% cost savings** for non-real-time workloads by
            processing requests asynchronously. This is ideal for tasks like data processing,
            content generation, or evaluation that don't require immediate responses.

            **Cost vs Latency Tradeoff:**

            - **Standard API**: Immediate results, full pricing
            - **Batch API**: 50% cost savings, asynchronous processing (results available within 24 hours)

            **Method 1: Direct batch management**

            Use ``batch_create()`` and ``batch_retrieve()`` for full control over the batch lifecycle:

            .. code-block:: python

                from langchain_openai import ChatOpenAI
                from langchain_core.messages import HumanMessage

                llm = ChatOpenAI(model="gpt-3.5-turbo")

                # Prepare multiple message sequences for batch processing
                messages_list = [
                    [HumanMessage(content="Translate 'hello' to French")],
                    [HumanMessage(content="Translate 'goodbye' to Spanish")],
                    [HumanMessage(content="What is the capital of Italy?")],
                ]

                # Create batch job (returns immediately with batch ID)
                batch_id = llm.batch_create(
                    messages_list=messages_list,
                    description="Translation and geography batch",
                    metadata={"project": "multilingual_qa", "user": "analyst_1"},
                )
                print(f"Batch created: {batch_id}")

                # Later, retrieve results (polls until completion)
                results = llm.batch_retrieve(
                    batch_id=batch_id,
                    poll_interval=60.0,  # Check every minute
                    timeout=3600.0,  # 1 hour timeout
                )

                # Process results
                for i, result in enumerate(results):
                    response = result.generations[0].message.content
                    print(f"Response {i + 1}: {response}")

            **Method 2: Enhanced batch() method**

            Use the familiar ``batch()`` method with ``use_batch_api=True`` for seamless integration:

            .. code-block:: python

                # Standard batch processing (immediate, full cost)
                inputs = [
                    [HumanMessage(content="What is 2+2?")],
                    [HumanMessage(content="What is 3+3?")],
                ]
                standard_results = llm.batch(inputs)  # Default: use_batch_api=False

                # Batch API processing (50% cost savings, polling)
                batch_results = llm.batch(
                    inputs,
                    use_batch_api=True,  # Enable cost savings
                    poll_interval=30.0,  # Poll every 30 seconds
                    timeout=1800.0,  # 30 minute timeout
                )

            **Batch creation with custom parameters:**

            .. code-block:: python

                # Create batch with specific model parameters
                batch_id = llm.batch_create(
                    messages_list=messages_list,
                    description="Creative writing batch",
                    metadata={"task_type": "content_generation"},
                    temperature=0.8,  # Higher creativity
                    max_tokens=200,  # Longer responses
                    top_p=0.9,  # Nucleus sampling
                )

            **Error handling and monitoring:**

            .. code-block:: python

                from langchain_openai.chat_models.batch import BatchError

                try:
                    batch_id = llm.batch_create(messages_list)
                    results = llm.batch_retrieve(batch_id, timeout=600.0)
                except BatchError as e:
                    print(f"Batch processing failed: {e}")
                    # Handle batch failure (retry, fallback to standard API, etc.)

            **Best practices:**

            - Use the Batch API for **non-urgent tasks** where 50% cost savings justify longer wait times
            - Set appropriate **timeouts** based on batch size (larger batches take longer)
            - Include **descriptive metadata** for tracking and debugging batch jobs
            - Consider **fallback strategies** for time-sensitive applications (see the sketch below)
            - Monitor batch status for **long-running jobs** to detect failures early
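
            A fallback strategy can be as simple as catching ``BatchError`` and
            re-running the same inputs through the standard path; a minimal sketch
            (the helper name here is illustrative, not part of this API):

            .. code-block:: python

                from langchain_openai.chat_models.batch import BatchError

                def batch_with_fallback(llm, inputs, timeout=1800.0):
                    # Try the cheaper Batch API first; fall back to immediate, full-price calls
                    try:
                        return llm.batch(inputs, use_batch_api=True, timeout=timeout)
                    except BatchError:
                        return llm.batch(inputs, use_batch_api=False)
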
            **When to use Batch API:**

            ✅ **Good for:**

            - Data processing and analysis
            - Content generation at scale
            - Model evaluation and testing
            - Batch translation or summarization
            - Non-interactive applications

            ❌ **Not suitable for:**

            - Real-time chat applications
            - Interactive user interfaces
            - Time-critical decision making
            - Applications requiring immediate responses

        """  # noqa: E501
        # TODO: Count bound tools as part of input.
        if tools is not None:
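For quick reference, a usage sketch of the token counter documented above (the method name ``get_num_tokens_from_messages`` is assumed from the published ``langchain-openai`` API; as noted, image counting additionally needs ``pillow``/``httpx``):

.. code-block:: python

    from langchain_core.messages import HumanMessage, SystemMessage
    from langchain_openai import ChatOpenAI

    llm = ChatOpenAI(model="gpt-4")
    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Summarize the OpenAI Batch API in one sentence."),
    ]
    # Counts tokens with tiktoken; image parts are skipped unless pillow/httpx are installed
    print(llm.get_num_tokens_from_messages(messages))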
@@ -2147,6 +2147,7 @@ class BaseChatOpenAI(BaseChatModel):
            else:
                filtered[k] = v
        return filtered

    def batch_create(
        self,
        messages_list: List[List[BaseMessage]],
@@ -2159,10 +2160,10 @@ class BaseChatOpenAI(BaseChatModel):
    ) -> str:
        """
        Create a batch job using OpenAI's Batch API for asynchronous processing.

        This method provides 50% cost savings compared to the standard API in exchange
        for asynchronous processing with polling for results.

        Args:
            messages_list: List of message sequences to process in batch.
            description: Optional description for the batch job.
@@ -2170,34 +2171,34 @@ class BaseChatOpenAI(BaseChatModel):
            poll_interval: Default time in seconds between status checks when polling.
            timeout: Default maximum time in seconds to wait for completion.
            **kwargs: Additional parameters to pass to chat completions.

        Returns:
            The batch ID for tracking the asynchronous job.

        Raises:
            BatchError: If batch creation fails.

        Example:
            .. code-block:: python

                from langchain_openai import ChatOpenAI
                from langchain_core.messages import HumanMessage

                llm = ChatOpenAI()
                messages_list = [
                    [HumanMessage(content="What is 2+2?")],
                    [HumanMessage(content="What is the capital of France?")],
                ]

                # Create batch job (50% cost savings)
                batch_id = llm.batch_create(messages_list)

                # Later, retrieve results
                results = llm.batch_retrieve(batch_id)
        """
        # Import here to avoid circular imports
        from langchain_openai.chat_models.batch import OpenAIBatchProcessor

        # Create batch processor with current model settings
        processor = OpenAIBatchProcessor(
            client=self.root_client,
@@ -2205,19 +2206,19 @@ class BaseChatOpenAI(BaseChatModel):
            poll_interval=poll_interval,
            timeout=timeout,
        )

        # Filter and prepare kwargs for batch processing
        batch_kwargs = self._get_invocation_params(**kwargs)
        # Remove model from kwargs since it's handled by the processor
        batch_kwargs.pop("model", None)

        return processor.create_batch(
            messages_list=messages_list,
            description=description,
            metadata=metadata,
            **batch_kwargs,
        )

    def batch_retrieve(
        self,
        batch_id: str,
@@ -2227,36 +2228,36 @@ class BaseChatOpenAI(BaseChatModel):
    ) -> List[ChatResult]:
        """
        Retrieve results from a batch job, polling until completion if necessary.

        This method will poll the batch status until completion and return the results
        converted to LangChain ChatResult format.

        Args:
            batch_id: The batch ID returned from batch_create().
            poll_interval: Time in seconds between status checks. Uses default if None.
            timeout: Maximum time in seconds to wait. Uses default if None.

        Returns:
            List of ChatResult objects corresponding to the original message sequences.

        Raises:
            BatchError: If batch retrieval fails, times out, or the batch job failed.

        Example:
            .. code-block:: python

                # After creating a batch job
                batch_id = llm.batch_create(messages_list)

                # Retrieve results (will poll until completion)
                results = llm.batch_retrieve(batch_id)

                for result in results:
                    print(result.generations[0].message.content)
        """
        # Import here to avoid circular imports
        from langchain_openai.chat_models.batch import OpenAIBatchProcessor

        # Create batch processor with current model settings
        processor = OpenAIBatchProcessor(
            client=self.root_client,
@@ -2264,15 +2265,14 @@ class BaseChatOpenAI(BaseChatModel):
            poll_interval=poll_interval or 10.0,
            timeout=timeout,
        )

        # Poll for completion and retrieve results
        processor.poll_batch_status(
            batch_id=batch_id, poll_interval=poll_interval, timeout=timeout
        )

        return processor.retrieve_batch_results(batch_id)

    @override
    def batch(
        self,
@@ -2285,11 +2285,11 @@ class BaseChatOpenAI(BaseChatModel):
    ) -> List[BaseMessage]:
        """
        Batch process multiple inputs using either the standard API or OpenAI's Batch API.

        This method provides two processing modes:
        1. Standard mode (use_batch_api=False): uses parallel invoke for immediate results
        2. Batch API mode (use_batch_api=True): uses OpenAI's Batch API for 50% cost savings

        Args:
            inputs: List of inputs to process in batch.
            config: Configuration for the batch processing.
@@ -2297,33 +2297,33 @@ class BaseChatOpenAI(BaseChatModel):
            use_batch_api: If True, use OpenAI's Batch API for cost savings with polling.
                If False (default), use standard parallel processing for immediate results.
            **kwargs: Additional parameters to pass to the underlying model.

        Returns:
            List of BaseMessage objects corresponding to the inputs.

        Raises:
            BatchError: If batch processing fails (when use_batch_api=True).

        Note:
            **Cost vs Latency Tradeoff:**

            - use_batch_api=False: Immediate results, standard API pricing
            - use_batch_api=True: 50% cost savings, asynchronous processing with polling

        Example:
            .. code-block:: python

                from langchain_openai import ChatOpenAI
                from langchain_core.messages import HumanMessage

                llm = ChatOpenAI()
                inputs = [
                    [HumanMessage(content="What is 2+2?")],
                    [HumanMessage(content="What is the capital of France?")],
                ]

                # Standard processing (immediate results)
                results = llm.batch(inputs)

                # Batch API processing (50% cost savings, polling required)
                results = llm.batch(inputs, use_batch_api=True)
        """
@@ -2338,11 +2338,11 @@ class BaseChatOpenAI(BaseChatModel):
                # Convert single input to list of messages
                messages = self._convert_input_to_messages(input_item)
                messages_list.append(messages)

            # Create batch job and poll for results
            batch_id = self.batch_create(messages_list, **kwargs)
            chat_results = self.batch_retrieve(batch_id)

            # Convert ChatResult objects to BaseMessage objects
            return [result.generations[0].message for result in chat_results]
        else:
@@ -2354,7 +2354,9 @@ class BaseChatOpenAI(BaseChatModel):
                **kwargs,
            )

    def _convert_input_to_messages(
        self, input_item: LanguageModelInput
    ) -> List[BaseMessage]:
        """Convert various input formats to a list of BaseMessage objects."""
        if isinstance(input_item, list):
            # Already a list of messages
@@ -2365,27 +2367,17 @@ class BaseChatOpenAI(BaseChatModel):
        elif isinstance(input_item, str):
            # String input - convert to HumanMessage
            from langchain_core.messages import HumanMessage

            return [HumanMessage(content=input_item)]
        elif hasattr(input_item, "to_messages"):
            # PromptValue or similar
            return input_item.to_messages()
        else:
            # Try to convert to string and then to HumanMessage
            from langchain_core.messages import HumanMessage

            return [HumanMessage(content=str(input_item))]

    def _get_generation_chunk_from_completion(
        self, completion: openai.BaseModel
    ) -> ChatGenerationChunk:
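The key takeaway from the hunks above is that ``batch()`` now routes through ``batch_create()``/``batch_retrieve()`` when ``use_batch_api=True``, normalizing each input with ``_convert_input_to_messages()``. A small sketch of the calling pattern, mirroring the docstring and the integration tests (not part of the patch itself):

.. code-block:: python

    from langchain_core.messages import HumanMessage
    from langchain_openai import ChatOpenAI

    llm = ChatOpenAI(model="gpt-3.5-turbo")

    # Plain strings and message lists can be mixed; strings become HumanMessage internally
    inputs = [
        "What is the largest planet? Answer with just the planet name.",
        [HumanMessage(content="What is the smallest planet? Answer with just the planet name.")],
    ]

    fast = llm.batch(inputs)  # standard parallel calls, immediate results
    cheap = llm.batch(inputs, use_batch_api=True, poll_interval=30.0, timeout=1800.0)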
@@ -5,19 +5,22 @@ from __future__ import annotations
import json
import time
from enum import Enum
from typing import Any, Dict, List, Optional
from uuid import uuid4

import openai
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatGeneration, ChatResult

from langchain_openai.chat_models._compat import (
    convert_dict_to_message,
    convert_message_to_dict,
)


class BatchStatus(str, Enum):
    """OpenAI Batch API status values."""

    VALIDATING = "validating"
    FAILED = "failed"
    IN_PROGRESS = "in_progress"
@@ -30,8 +33,10 @@ class BatchStatus(str, Enum):

class BatchError(Exception):
    """Exception raised when batch processing fails."""

    def __init__(
        self, message: str, batch_id: Optional[str] = None, status: Optional[str] = None
    ):
        super().__init__(message)
        self.batch_id = batch_id
        self.status = status
@@ -39,22 +44,22 @@ class BatchError(Exception):

class OpenAIBatchClient:
    """
    OpenAI Batch API client wrapper that handles batch creation, status polling,
    and result retrieval.

    This class provides a high-level interface to OpenAI's Batch API, which offers
    50% cost savings compared to the standard API in exchange for asynchronous processing.
    """

    def __init__(self, client: openai.OpenAI):
        """
        Initialize the batch client.

        Args:
            client: OpenAI client instance to use for API calls.
        """
        self.client = client

def create_batch(
|
||||
self,
|
||||
requests: List[Dict[str, Any]],
|
||||
@@ -63,53 +68,52 @@ class OpenAIBatchClient:
|
||||
) -> str:
|
||||
"""
|
||||
Create a new batch job with the OpenAI Batch API.
|
||||
|
||||
|
||||
Args:
|
||||
requests: List of request objects in OpenAI batch format.
|
||||
description: Optional description for the batch job.
|
||||
metadata: Optional metadata to attach to the batch job.
|
||||
|
||||
|
||||
Returns:
|
||||
The batch ID for tracking the job.
|
||||
|
||||
|
||||
Raises:
|
||||
BatchError: If batch creation fails.
|
||||
"""
|
||||
try:
|
||||
# Create JSONL content for the batch
|
||||
jsonl_content = "\n".join(json.dumps(req) for req in requests)
|
||||
|
||||
|
||||
# Upload the batch file
|
||||
file_response = self.client.files.create(
|
||||
file=jsonl_content.encode("utf-8"), purpose="batch"
|
||||
)
|
||||
|
||||
|
||||
# Create the batch job
|
||||
batch_response = self.client.batches.create(
|
||||
input_file_id=file_response.id,
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
metadata=metadata or {},
|
||||
)
|
||||
|
||||
|
||||
return batch_response.id
|
||||
|
||||
|
||||
except openai.OpenAIError as e:
|
||||
raise BatchError(f"Failed to create batch: {e}") from e
|
||||
except Exception as e:
|
||||
raise BatchError(f"Unexpected error creating batch: {e}") from e
|
||||
|
||||
|
||||
def retrieve_batch(self, batch_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Retrieve batch information by ID.
|
||||
|
||||
|
||||
Args:
|
||||
batch_id: The batch ID to retrieve.
|
||||
|
||||
|
||||
Returns:
|
||||
Dictionary containing batch information including status.
|
||||
|
||||
|
||||
Raises:
|
||||
BatchError: If batch retrieval fails.
|
||||
"""
|
||||
@@ -119,20 +123,24 @@ class OpenAIBatchClient:
|
||||
"id": batch.id,
|
||||
"status": batch.status,
|
||||
"created_at": batch.created_at,
|
||||
"completed_at": getattr(batch, 'completed_at', None),
|
||||
"failed_at": getattr(batch, 'failed_at', None),
|
||||
"expired_at": getattr(batch, 'expired_at', None),
|
||||
"request_counts": getattr(batch, 'request_counts', {}),
|
||||
"metadata": getattr(batch, 'metadata', {}),
|
||||
"errors": getattr(batch, 'errors', None),
|
||||
"output_file_id": getattr(batch, 'output_file_id', None),
|
||||
"error_file_id": getattr(batch, 'error_file_id', None),
|
||||
"completed_at": getattr(batch, "completed_at", None),
|
||||
"failed_at": getattr(batch, "failed_at", None),
|
||||
"expired_at": getattr(batch, "expired_at", None),
|
||||
"request_counts": getattr(batch, "request_counts", {}),
|
||||
"metadata": getattr(batch, "metadata", {}),
|
||||
"errors": getattr(batch, "errors", None),
|
||||
"output_file_id": getattr(batch, "output_file_id", None),
|
||||
"error_file_id": getattr(batch, "error_file_id", None),
|
||||
}
|
||||
except openai.OpenAIError as e:
|
||||
raise BatchError(f"Failed to retrieve batch {batch_id}: {e}", batch_id=batch_id) from e
|
||||
raise BatchError(
|
||||
f"Failed to retrieve batch {batch_id}: {e}", batch_id=batch_id
|
||||
) from e
|
||||
except Exception as e:
|
||||
raise BatchError(f"Unexpected error retrieving batch {batch_id}: {e}", batch_id=batch_id) from e
|
||||
|
||||
raise BatchError(
|
||||
f"Unexpected error retrieving batch {batch_id}: {e}", batch_id=batch_id
|
||||
) from e
|
||||
|
||||
def poll_batch_status(
|
||||
self,
|
||||
batch_id: str,
|
||||
@@ -141,94 +149,106 @@ class OpenAIBatchClient:
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Poll batch status until completion or failure.
|
||||
|
||||
|
||||
Args:
|
||||
batch_id: The batch ID to poll.
|
||||
poll_interval: Time in seconds between status checks.
|
||||
timeout: Maximum time in seconds to wait. None for no timeout.
|
||||
|
||||
|
||||
Returns:
|
||||
Final batch information when completed.
|
||||
|
||||
|
||||
Raises:
|
||||
BatchError: If batch fails or times out.
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
while True:
|
||||
batch_info = self.retrieve_batch(batch_id)
|
||||
status = batch_info["status"]
|
||||
|
||||
|
||||
if status == BatchStatus.COMPLETED:
|
||||
return batch_info
|
||||
elif status in [
|
||||
BatchStatus.FAILED,
|
||||
BatchStatus.EXPIRED,
|
||||
BatchStatus.CANCELLED,
|
||||
]:
|
||||
error_msg = f"Batch {batch_id} failed with status: {status}"
|
||||
if batch_info.get("errors"):
|
||||
error_msg += f". Errors: {batch_info['errors']}"
|
||||
raise BatchError(error_msg, batch_id=batch_id, status=status)
|
||||
|
||||
|
||||
# Check timeout
|
||||
if timeout and (time.time() - start_time) > timeout:
|
||||
raise BatchError(
|
||||
f"Batch {batch_id} timed out after {timeout} seconds. Current status: {status}",
|
||||
batch_id=batch_id,
|
||||
status=status,
|
||||
)
|
||||
|
||||
|
||||
time.sleep(poll_interval)
|
||||
|
||||
|
||||
def retrieve_batch_results(self, batch_id: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Retrieve results from a completed batch.
|
||||
|
||||
|
||||
Args:
|
||||
batch_id: The batch ID to retrieve results for.
|
||||
|
||||
|
||||
Returns:
|
||||
List of result objects from the batch.
|
||||
|
||||
|
||||
Raises:
|
||||
BatchError: If batch is not completed or result retrieval fails.
|
||||
"""
|
||||
try:
|
||||
batch_info = self.retrieve_batch(batch_id)
|
||||
|
||||
|
||||
if batch_info["status"] != BatchStatus.COMPLETED:
|
||||
raise BatchError(
|
||||
f"Batch {batch_id} is not completed. Current status: {batch_info['status']}",
|
||||
batch_id=batch_id,
|
||||
status=batch_info["status"]
|
||||
status=batch_info["status"],
|
||||
)
|
||||
|
||||
|
||||
output_file_id = batch_info.get("output_file_id")
|
||||
if not output_file_id:
|
||||
raise BatchError(f"No output file found for batch {batch_id}", batch_id=batch_id)
|
||||
|
||||
raise BatchError(
|
||||
f"No output file found for batch {batch_id}", batch_id=batch_id
|
||||
)
|
||||
|
||||
# Download and parse the results file
|
||||
file_content = self.client.files.content(output_file_id)
|
||||
results = []
|
||||
|
||||
for line in file_content.text.strip().split("\n"):
|
||||
if line.strip():
|
||||
results.append(json.loads(line))
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
except openai.OpenAIError as e:
|
||||
raise BatchError(f"Failed to retrieve results for batch {batch_id}: {e}", batch_id=batch_id) from e
|
||||
raise BatchError(
|
||||
f"Failed to retrieve results for batch {batch_id}: {e}",
|
||||
batch_id=batch_id,
|
||||
) from e
|
||||
except Exception as e:
|
||||
raise BatchError(f"Unexpected error retrieving results for batch {batch_id}: {e}", batch_id=batch_id) from e
|
||||
|
||||
raise BatchError(
|
||||
f"Unexpected error retrieving results for batch {batch_id}: {e}",
|
||||
batch_id=batch_id,
|
||||
) from e
|
||||
|
||||
def cancel_batch(self, batch_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Cancel a batch job.
|
||||
|
||||
|
||||
Args:
|
||||
batch_id: The batch ID to cancel.
|
||||
|
||||
|
||||
Returns:
|
||||
Updated batch information after cancellation.
|
||||
|
||||
|
||||
Raises:
|
||||
BatchError: If batch cancellation fails.
|
||||
"""
|
||||
@@ -236,22 +256,26 @@ class OpenAIBatchClient:
|
||||
batch = self.client.batches.cancel(batch_id)
|
||||
return self.retrieve_batch(batch_id)
|
||||
except openai.OpenAIError as e:
|
||||
raise BatchError(f"Failed to cancel batch {batch_id}: {e}", batch_id=batch_id) from e
|
||||
raise BatchError(
|
||||
f"Failed to cancel batch {batch_id}: {e}", batch_id=batch_id
|
||||
) from e
|
||||
except Exception as e:
|
||||
raise BatchError(f"Unexpected error cancelling batch {batch_id}: {e}", batch_id=batch_id) from e
|
||||
raise BatchError(
|
||||
f"Unexpected error cancelling batch {batch_id}: {e}", batch_id=batch_id
|
||||
) from e
|
||||
|
||||
|
||||
class OpenAIBatchProcessor:
|
||||
"""
|
||||
High-level processor for managing OpenAI Batch API lifecycle with LangChain integration.
|
||||
|
||||
|
||||
This class handles the complete batch processing workflow:
|
||||
1. Converts LangChain messages to OpenAI batch format
|
||||
2. Creates batch jobs using the OpenAI Batch API
|
||||
3. Polls for completion with configurable intervals
|
||||
4. Converts results back to LangChain format
|
||||
"""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: openai.OpenAI,
|
||||
@@ -261,7 +285,7 @@ class OpenAIBatchProcessor:
|
||||
):
|
||||
"""
|
||||
Initialize the batch processor.
|
||||
|
||||
|
||||
Args:
|
||||
client: OpenAI client instance to use for API calls.
|
||||
model: The model to use for batch requests.
|
||||
@@ -272,7 +296,7 @@ class OpenAIBatchProcessor:
|
||||
self.model = model
|
||||
self.poll_interval = poll_interval
|
||||
self.timeout = timeout
|
||||
|
||||
|
||||
def create_batch(
|
||||
self,
|
||||
messages_list: List[List[BaseMessage]],
|
||||
@@ -282,16 +306,16 @@ class OpenAIBatchProcessor:
|
||||
) -> str:
|
||||
"""
|
||||
Create a batch job from a list of LangChain message sequences.
|
||||
|
||||
|
||||
Args:
|
||||
messages_list: List of message sequences to process in batch.
|
||||
description: Optional description for the batch job.
|
||||
metadata: Optional metadata to attach to the batch job.
|
||||
**kwargs: Additional parameters to pass to chat completions.
|
||||
|
||||
|
||||
Returns:
|
||||
The batch ID for tracking the job.
|
||||
|
||||
|
||||
Raises:
|
||||
BatchError: If batch creation fails.
|
||||
"""
|
||||
@@ -300,19 +324,14 @@ class OpenAIBatchProcessor:
|
||||
for i, messages in enumerate(messages_list):
|
||||
custom_id = f"request_{i}_{uuid4().hex[:8]}"
|
||||
request = create_batch_request(
|
||||
messages=messages, model=self.model, custom_id=custom_id, **kwargs
|
||||
)
|
||||
requests.append(request)
|
||||
|
||||
|
||||
return self.batch_client.create_batch(
|
||||
requests=requests, description=description, metadata=metadata
|
||||
)
|
||||
|
||||
|
||||
def poll_batch_status(
|
||||
self,
|
||||
batch_id: str,
|
||||
@@ -321,15 +340,15 @@ class OpenAIBatchProcessor:
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Poll batch status until completion or failure.
|
||||
|
||||
|
||||
Args:
|
||||
batch_id: The batch ID to poll.
|
||||
poll_interval: Time in seconds between status checks. Uses default if None.
|
||||
timeout: Maximum time in seconds to wait. Uses default if None.
|
||||
|
||||
|
||||
Returns:
|
||||
Final batch information when completed.
|
||||
|
||||
|
||||
Raises:
|
||||
BatchError: If batch fails or times out.
|
||||
"""
|
||||
@@ -338,26 +357,26 @@ class OpenAIBatchProcessor:
|
||||
poll_interval=poll_interval or self.poll_interval,
|
||||
timeout=timeout or self.timeout,
|
||||
)
|
||||
|
||||
|
||||
def retrieve_batch_results(self, batch_id: str) -> List[ChatResult]:
|
||||
"""
|
||||
Retrieve and convert batch results to LangChain format.
|
||||
|
||||
|
||||
Args:
|
||||
batch_id: The batch ID to retrieve results for.
|
||||
|
||||
|
||||
Returns:
|
||||
List of ChatResult objects corresponding to the original message sequences.
|
||||
|
||||
|
||||
Raises:
|
||||
BatchError: If batch is not completed or result retrieval fails.
|
||||
"""
|
||||
# Get raw results from OpenAI
|
||||
raw_results = self.batch_client.retrieve_batch_results(batch_id)
|
||||
|
||||
|
||||
# Sort results by custom_id to maintain order
|
||||
raw_results.sort(key=lambda x: x.get("custom_id", ""))
|
||||
|
||||
|
||||
# Convert to LangChain ChatResult format
|
||||
chat_results = []
|
||||
for result in raw_results:
|
||||
@@ -365,39 +384,42 @@ class OpenAIBatchProcessor:
|
||||
# Handle individual request errors
|
||||
error_msg = f"Request failed: {result['error']}"
|
||||
raise BatchError(error_msg, batch_id=batch_id)
|
||||
|
||||
|
||||
response = result.get("response", {})
|
||||
if not response:
|
||||
raise BatchError(f"No response found in result: {result}", batch_id=batch_id)
|
||||
|
||||
raise BatchError(
|
||||
f"No response found in result: {result}", batch_id=batch_id
|
||||
)
|
||||
|
||||
body = response.get("body", {})
|
||||
choices = body.get("choices", [])
|
||||
|
||||
|
||||
if not choices:
|
||||
raise BatchError(f"No choices found in response: {body}", batch_id=batch_id)
|
||||
|
||||
raise BatchError(
|
||||
f"No choices found in response: {body}", batch_id=batch_id
|
||||
)
|
||||
|
||||
# Convert OpenAI response to LangChain format
|
||||
generations = []
|
||||
for choice in choices:
|
||||
message_dict = choice.get("message", {})
|
||||
if not message_dict:
|
||||
continue
|
||||
|
||||
|
||||
# Convert OpenAI message dict to LangChain message
|
||||
message = convert_dict_to_message(message_dict)
|
||||
|
||||
|
||||
# Create ChatGeneration with metadata
|
||||
generation_info = {
|
||||
"finish_reason": choice.get("finish_reason"),
|
||||
"logprobs": choice.get("logprobs"),
|
||||
}
|
||||
|
||||
|
||||
generation = ChatGeneration(
|
||||
message=message, generation_info=generation_info
|
||||
)
|
||||
generations.append(generation)
|
||||
|
||||
|
||||
# Create ChatResult with usage information
|
||||
usage = body.get("usage", {})
|
||||
llm_output = {
|
||||
@@ -405,15 +427,12 @@ class OpenAIBatchProcessor:
|
||||
"model_name": body.get("model"),
|
||||
"system_fingerprint": body.get("system_fingerprint"),
|
||||
}
|
||||
|
||||
chat_result = ChatResult(generations=generations, llm_output=llm_output)
|
||||
chat_results.append(chat_result)
|
||||
|
||||
|
||||
return chat_results
|
||||
|
||||
|
||||
def process_batch(
|
||||
self,
|
||||
messages_list: List[List[BaseMessage]],
|
||||
@@ -425,7 +444,7 @@ class OpenAIBatchProcessor:
|
||||
) -> List[ChatResult]:
|
||||
"""
|
||||
Complete batch processing workflow: create, poll, and retrieve results.
|
||||
|
||||
|
||||
Args:
|
||||
messages_list: List of message sequences to process in batch.
|
||||
description: Optional description for the batch job.
|
||||
@@ -433,10 +452,10 @@ class OpenAIBatchProcessor:
|
||||
poll_interval: Time in seconds between status checks. Uses default if None.
|
||||
timeout: Maximum time in seconds to wait. Uses default if None.
|
||||
**kwargs: Additional parameters to pass to chat completions.
|
||||
|
||||
|
||||
Returns:
|
||||
List of ChatResult objects corresponding to the original message sequences.
|
||||
|
||||
|
||||
Raises:
|
||||
BatchError: If any step of the batch processing fails.
|
||||
"""
|
||||
@@ -445,50 +464,39 @@ class OpenAIBatchProcessor:
|
||||
messages_list=messages_list,
|
||||
description=description,
|
||||
metadata=metadata,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
# Poll until completion
|
||||
self.poll_batch_status(
|
||||
batch_id=batch_id, poll_interval=poll_interval, timeout=timeout
|
||||
)
|
||||
|
||||
|
||||
# Retrieve and return results
|
||||
return self.retrieve_batch_results(batch_id)
|
||||
|
||||
|
||||
def create_batch_request(
|
||||
messages: List[BaseMessage], model: str, custom_id: str, **kwargs: Any
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create a batch request object from LangChain messages.
|
||||
|
||||
|
||||
Args:
|
||||
messages: List of LangChain messages to convert.
|
||||
model: The model to use for the request.
|
||||
custom_id: Unique identifier for this request within the batch.
|
||||
**kwargs: Additional parameters to pass to the chat completion.
|
||||
|
||||
|
||||
Returns:
|
||||
Dictionary in OpenAI batch request format.
|
||||
"""
|
||||
# Convert LangChain messages to OpenAI format
|
||||
openai_messages = [convert_message_to_dict(msg) for msg in messages]
|
||||
|
||||
|
||||
return {
|
||||
"custom_id": custom_id,
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
"model": model,
|
||||
"messages": openai_messages,
|
||||
**kwargs
|
||||
}
|
||||
"body": {"model": model, "messages": openai_messages, **kwargs},
|
||||
}
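Taken together, the classes above define the whole pipeline: convert messages to batch requests, upload a JSONL file, poll, and convert results back. A minimal end-to-end sketch of how this module is meant to be used, assuming an ``openai.OpenAI`` client configured via the usual environment variables and the signatures documented above:

.. code-block:: python

    import openai
    from langchain_core.messages import HumanMessage
    from langchain_openai.chat_models.batch import OpenAIBatchProcessor, create_batch_request

    # Each JSONL line is a plain dict in OpenAI's /v1/chat/completions batch request format
    request = create_batch_request(
        messages=[HumanMessage(content="What is 2+2?")],
        model="gpt-3.5-turbo",
        custom_id="request_0_example",
        temperature=0.0,
    )
    # -> {"custom_id": ..., "method": "POST", "url": "/v1/chat/completions", "body": {...}}

    processor = OpenAIBatchProcessor(client=openai.OpenAI(), model="gpt-3.5-turbo")
    results = processor.process_batch(
        messages_list=[[HumanMessage(content="What is 2+2?")]],
        description="example batch",
        timeout=3600.0,
    )
    print(results[0].generations[0].message.content)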
@@ -6,20 +6,18 @@ They are designed to test the complete end-to-end batch processing workflow.
|
||||
|
||||
import os
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.outputs import ChatResult
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_openai.chat_models.batch import BatchError
|
||||
|
||||
|
||||
# Skip all tests if no API key is available
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not os.environ.get("OPENAI_API_KEY"),
|
||||
reason="OPENAI_API_KEY not set, skipping integration tests"
|
||||
reason="OPENAI_API_KEY not set, skipping integration tests",
|
||||
)
|
||||
|
||||
|
||||
@@ -31,7 +29,7 @@ class TestBatchAPIIntegration:
|
||||
self.llm = ChatOpenAI(
|
||||
model="gpt-3.5-turbo",
|
||||
temperature=0.1, # Low temperature for consistent results
|
||||
max_tokens=50, # Keep responses short for faster processing
|
||||
)
|
||||
|
||||
@pytest.mark.scheduled
|
||||
@@ -40,7 +38,11 @@ class TestBatchAPIIntegration:
|
||||
# Create a small batch of simple questions
|
||||
messages_list = [
|
||||
[HumanMessage(content="What is 2+2? Answer with just the number.")],
|
||||
[HumanMessage(content="What is the capital of France? Answer with just the city name.")],
|
||||
[
|
||||
HumanMessage(
|
||||
content="What is the capital of France? Answer with just the city name."
|
||||
)
|
||||
],
|
||||
]
|
||||
|
||||
# Create batch job
|
||||
@@ -58,19 +60,21 @@ class TestBatchAPIIntegration:
|
||||
results = self.llm.batch_retrieve(
|
||||
batch_id=batch_id,
|
||||
poll_interval=30.0, # Poll every 30 seconds
|
||||
timeout=1800.0, # 30 minute timeout
|
||||
)
|
||||
|
||||
# Verify results
|
||||
assert len(results) == 2
|
||||
assert all(isinstance(result, ChatResult) for result in results)
|
||||
assert all(len(result.generations) == 1 for result in results)
|
||||
assert all(
|
||||
isinstance(result.generations[0].message, AIMessage) for result in results
|
||||
)
|
||||
|
||||
# Check that we got reasonable responses
|
||||
response1 = results[0].generations[0].message.content.strip()
|
||||
response2 = results[1].generations[0].message.content.strip()
|
||||
|
||||
|
||||
# Basic sanity checks (responses should contain expected content)
|
||||
assert "4" in response1 or "four" in response1.lower()
|
||||
assert "paris" in response2.lower()
|
||||
@@ -85,10 +89,7 @@ class TestBatchAPIIntegration:
|
||||
|
||||
# Use batch API mode
|
||||
results = self.llm.batch(
|
||||
inputs, use_batch_api=True, poll_interval=30.0, timeout=1800.0
|
||||
)
|
||||
|
||||
# Verify results
|
||||
@@ -99,37 +100,32 @@ class TestBatchAPIIntegration:
|
||||
# Basic sanity checks
|
||||
response1 = results[0].content.strip().lower()
|
||||
response2 = results[1].content.strip().lower()
|
||||
|
||||
|
||||
assert any(char in response1 for char in ["1", "2", "3"])
|
||||
assert "blue" in response2
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_batch_method_comparison(self):
|
||||
"""Test that batch API and standard batch produce similar results."""
|
||||
inputs = [[HumanMessage(content="What is 1+1? Answer with just the number.")]]
|
||||
|
||||
# Test standard batch processing
|
||||
standard_results = self.llm.batch(inputs, use_batch_api=False)
|
||||
|
||||
|
||||
# Test batch API processing
|
||||
batch_api_results = self.llm.batch(
|
||||
inputs, use_batch_api=True, poll_interval=30.0, timeout=1800.0
|
||||
)
|
||||
|
||||
# Both should return similar structure
|
||||
assert len(standard_results) == len(batch_api_results) == 1
|
||||
assert isinstance(standard_results[0], AIMessage)
|
||||
assert isinstance(batch_api_results[0], AIMessage)
|
||||
|
||||
|
||||
# Both should contain reasonable answers
|
||||
standard_content = standard_results[0].content.strip()
|
||||
batch_content = batch_api_results[0].content.strip()
|
||||
|
||||
|
||||
assert "2" in standard_content or "two" in standard_content.lower()
|
||||
assert "2" in batch_content or "two" in batch_content.lower()
|
||||
|
||||
@@ -137,7 +133,7 @@ class TestBatchAPIIntegration:
|
||||
def test_batch_with_different_parameters(self):
|
||||
"""Test batch processing with different model parameters."""
|
||||
messages_list = [
|
||||
[HumanMessage(content="Write a haiku about coding. Keep it short.")],
|
||||
[HumanMessage(content="Write a haiku about coding. Keep it short.")]
|
||||
]
|
||||
|
||||
# Create batch with specific parameters
|
||||
@@ -146,18 +142,16 @@ class TestBatchAPIIntegration:
|
||||
description="Integration test - parameters",
|
||||
metadata={"test_type": "parameters"},
|
||||
temperature=0.8, # Higher temperature for creativity
|
||||
max_tokens=100, # More tokens for haiku
|
||||
)
|
||||
|
||||
results = self.llm.batch_retrieve(
|
||||
batch_id=batch_id, poll_interval=30.0, timeout=1800.0
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
result_content = results[0].generations[0].message.content
|
||||
|
||||
|
||||
# Should have some content (haiku)
|
||||
assert len(result_content.strip()) > 10
|
||||
# Haikus typically have line breaks
|
||||
@@ -170,25 +164,24 @@ class TestBatchAPIIntegration:
|
||||
|
||||
messages_list = [
|
||||
[
|
||||
SystemMessage(content="You are a helpful math tutor. Answer concisely."),
|
||||
SystemMessage(
|
||||
content="You are a helpful math tutor. Answer concisely."
|
||||
),
|
||||
HumanMessage(content="What is 5 * 6?"),
|
||||
],
|
||||
]
|
||||
]
|
||||
|
||||
batch_id = self.llm.batch_create(
|
||||
messages_list=messages_list, description="Integration test - system message"
|
||||
)
|
||||
|
||||
results = self.llm.batch_retrieve(
|
||||
batch_id=batch_id, poll_interval=30.0, timeout=1800.0
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
result_content = results[0].generations[0].message.content.strip()
|
||||
|
||||
|
||||
# Should contain the answer
|
||||
assert "30" in result_content or "thirty" in result_content.lower()
|
||||
|
||||
@@ -196,14 +189,9 @@ class TestBatchAPIIntegration:
|
||||
def test_batch_error_handling_invalid_model(self):
|
||||
"""Test error handling with invalid model parameters."""
|
||||
# Create a ChatOpenAI instance with an invalid model
|
||||
invalid_llm = ChatOpenAI(model="invalid-model-name-12345", temperature=0.1)
|
||||
|
||||
messages_list = [[HumanMessage(content="Hello")]]
|
||||
|
||||
# This should fail during batch creation or processing
|
||||
with pytest.raises(BatchError):
|
||||
@@ -216,23 +204,24 @@ class TestBatchAPIIntegration:
|
||||
# Test with string inputs (should be converted to HumanMessage)
|
||||
inputs = [
|
||||
"What is the largest planet? Answer with just the planet name.",
|
||||
[HumanMessage(content="What is the smallest planet? Answer with just the planet name.")],
|
||||
[
|
||||
HumanMessage(
|
||||
content="What is the smallest planet? Answer with just the planet name."
|
||||
)
|
||||
],
|
||||
]
|
||||
|
||||
results = self.llm.batch(
|
||||
inputs, use_batch_api=True, poll_interval=30.0, timeout=1800.0
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
assert all(isinstance(result, AIMessage) for result in results)
|
||||
|
||||
|
||||
# Check for reasonable responses
|
||||
response1 = results[0].content.strip().lower()
|
||||
response2 = results[1].content.strip().lower()
|
||||
|
||||
|
||||
assert "jupiter" in response1
|
||||
assert "mercury" in response2
|
||||
|
||||
@@ -248,12 +237,10 @@ class TestBatchAPIIntegration:
|
||||
results = self.llm.batch_retrieve(batch_id, timeout=300.0)
|
||||
assert results == []
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_batch_metadata_preservation(self):
|
||||
"""Test that batch metadata is properly handled."""
|
||||
messages_list = [[HumanMessage(content="Say 'test successful'")]]
|
||||
|
||||
metadata = {
|
||||
"test_name": "metadata_test",
|
||||
@@ -270,7 +257,7 @@ class TestBatchAPIIntegration:
|
||||
|
||||
# Retrieve results
|
||||
results = self.llm.batch_retrieve(batch_id, timeout=1800.0)
|
||||
|
||||
|
||||
assert len(results) == 1
|
||||
result_content = results[0].generations[0].message.content.strip().lower()
|
||||
assert "test successful" in result_content
|
||||
@@ -281,18 +268,12 @@ class TestBatchAPIEdgeCases:
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures."""
|
||||
self.llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.1, max_tokens=50)
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_batch_with_very_short_timeout(self):
|
||||
"""Test batch processing with very short timeout."""
|
||||
messages_list = [[HumanMessage(content="Hello")]]
|
||||
|
||||
batch_id = self.llm.batch_create(messages_list=messages_list)
|
||||
|
||||
@@ -313,10 +294,8 @@ class TestBatchAPIEdgeCases:
|
||||
def test_batch_with_long_content(self):
|
||||
"""Test batch processing with longer content."""
|
||||
long_content = "Please summarize this text: " + "This is a test sentence. " * 20
|
||||
|
||||
messages_list = [[HumanMessage(content=long_content)]]
|
||||
|
||||
batch_id = self.llm.batch_create(
|
||||
messages_list=messages_list,
|
||||
@@ -324,10 +303,10 @@ class TestBatchAPIEdgeCases:
|
||||
)
|
||||
|
||||
results = self.llm.batch_retrieve(batch_id, timeout=1800.0)
|
||||
|
||||
|
||||
assert len(results) == 1
|
||||
result_content = results[0].generations[0].message.content
|
||||
|
||||
|
||||
# Should have some summary content
|
||||
assert len(result_content.strip()) > 10
|
||||
|
||||
@@ -353,7 +332,7 @@ class TestBatchAPIPerformance:
|
||||
]
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
batch_id = self.llm.batch_create(
|
||||
messages_list=messages_list,
|
||||
description="Medium batch test - 10 requests",
|
||||
@@ -363,7 +342,7 @@ class TestBatchAPIPerformance:
|
||||
results = self.llm.batch_retrieve(
|
||||
batch_id=batch_id,
|
||||
poll_interval=60.0, # Poll every minute
|
||||
timeout=3600.0, # 1 hour timeout
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
@@ -401,20 +380,17 @@ class TestBatchAPIPerformance:
|
||||
# Test batch API processing time
|
||||
start_batch = time.time()
|
||||
batch_results = self.llm.batch(
|
||||
messages, use_batch_api=True, poll_interval=30.0, timeout=1800.0
|
||||
)
|
||||
batch_time = time.time() - start_batch
|
||||
|
||||
# Verify both produce results
|
||||
assert len(sequential_results) == len(batch_results) == 2
|
||||
|
||||
|
||||
# Log timing comparison
|
||||
print(f"Sequential processing: {sequential_time:.2f}s")
|
||||
print(f"Batch API processing: {batch_time:.2f}s")
|
||||
|
||||
|
||||
# Note: Batch API will typically be slower for small batches due to polling,
|
||||
# but should be more cost-effective for larger batches
|
||||
|
||||
@@ -424,6 +400,7 @@ def is_openai_api_available() -> bool:
|
||||
"""Check if OpenAI API is available and accessible."""
|
||||
try:
|
||||
import openai
|
||||
|
||||
client = openai.OpenAI()
|
||||
# Try a simple API call to verify connectivity
|
||||
client.models.list()
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
"""Test OpenAI Batch API functionality."""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Any, Dict, List
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
@@ -42,7 +40,10 @@ class TestOpenAIBatchClient:
|
||||
"custom_id": "request-1",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello"}]},
|
||||
"body": {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
@@ -64,7 +65,10 @@ class TestOpenAIBatchClient:
|
||||
"custom_id": "request-1",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello"}]},
|
||||
"body": {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
@@ -76,10 +80,10 @@ class TestOpenAIBatchClient:
|
||||
# Mock batch status progression
|
||||
mock_batch_validating = MagicMock()
|
||||
mock_batch_validating.status = "validating"
|
||||
|
||||
|
||||
mock_batch_in_progress = MagicMock()
|
||||
mock_batch_in_progress.status = "in_progress"
|
||||
|
||||
|
||||
mock_batch_completed = MagicMock()
|
||||
mock_batch_completed.status = "completed"
|
||||
mock_batch_completed.output_file_id = "file_123"
|
||||
@@ -134,13 +138,17 @@ class TestOpenAIBatchClient:
|
||||
{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I help you?"
|
||||
"content": "Hello! How can I help you?",
|
||||
}
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 10, "completion_tokens": 8, "total_tokens": 18}
|
||||
}
|
||||
}
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 8,
|
||||
"total_tokens": 18,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
@@ -177,7 +185,7 @@ class TestOpenAIBatchProcessor:
|
||||
def test_create_batch_success(self):
|
||||
"""Test successful batch creation with message conversion."""
|
||||
# Mock batch client
|
||||
with patch.object(self.processor, "batch_client") as mock_batch_client:
|
||||
mock_batch_client.create_batch.return_value = "batch_123"
|
||||
|
||||
messages_list = [
|
||||
@@ -198,7 +206,7 @@ class TestOpenAIBatchProcessor:
|
||||
# Verify batch requests were created correctly
|
||||
call_args = mock_batch_client.create_batch.call_args
|
||||
batch_requests = call_args[1]["batch_requests"]
|
||||
|
||||
|
||||
assert len(batch_requests) == 2
|
||||
assert batch_requests[0]["custom_id"] == "request-0"
|
||||
assert batch_requests[0]["body"]["model"] == "gpt-3.5-turbo"
|
||||
@@ -208,7 +216,7 @@ class TestOpenAIBatchProcessor:
|
||||
|
||||
def test_poll_batch_status_success(self):
|
||||
"""Test successful batch status polling."""
|
||||
with patch.object(self.processor, "batch_client") as mock_batch_client:
|
||||
mock_batch = MagicMock()
|
||||
mock_batch.status = "completed"
|
||||
mock_batch_client.poll_batch_status.return_value = mock_batch
|
||||
@@ -238,13 +246,17 @@ class TestOpenAIBatchProcessor:
|
||||
{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "2+2 equals 4."
|
||||
"content": "2+2 equals 4.",
|
||||
}
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 10, "completion_tokens": 8, "total_tokens": 18}
|
||||
}
|
||||
}
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 8,
|
||||
"total_tokens": 18,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "batch_req_124",
|
||||
@@ -256,35 +268,42 @@ class TestOpenAIBatchProcessor:
|
||||
{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "The capital of France is Paris."
|
||||
"content": "The capital of France is Paris.",
|
||||
}
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 12, "completion_tokens": 10, "total_tokens": 22}
|
||||
}
|
||||
}
|
||||
}
|
||||
"usage": {
|
||||
"prompt_tokens": 12,
|
||||
"completion_tokens": 10,
|
||||
"total_tokens": 22,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
with patch.object(self.processor, "batch_client") as mock_batch_client:
|
||||
mock_batch_client.poll_batch_status.return_value = mock_batch
|
||||
mock_batch_client.retrieve_batch_results.return_value = mock_results
|
||||
|
||||
chat_results = self.processor.retrieve_batch_results("batch_123")
|
||||
|
||||
assert len(chat_results) == 2
|
||||
|
||||
|
||||
# Check first result
|
||||
assert isinstance(chat_results[0], ChatResult)
|
||||
assert len(chat_results[0].generations) == 1
|
||||
assert isinstance(chat_results[0].generations[0].message, AIMessage)
|
||||
assert chat_results[0].generations[0].message.content == "2+2 equals 4."
|
||||
|
||||
|
||||
# Check second result
|
||||
assert isinstance(chat_results[1], ChatResult)
|
||||
assert len(chat_results[1].generations) == 1
|
||||
assert isinstance(chat_results[1].generations[0].message, AIMessage)
|
||||
assert (
|
||||
chat_results[1].generations[0].message.content
|
||||
== "The capital of France is Paris."
|
||||
)
|
||||
|
||||
def test_retrieve_batch_results_with_errors(self):
|
||||
"""Test batch result retrieval with some failed requests."""
|
||||
@@ -303,13 +322,17 @@ class TestOpenAIBatchProcessor:
                            {
                                "message": {
                                    "role": "assistant",
                                    "content": "Success response"
                                    "content": "Success response",
                                }
                            }
                        ],
                        "usage": {"prompt_tokens": 10, "completion_tokens": 8, "total_tokens": 18}
                    }
                }
                        "usage": {
                            "prompt_tokens": 10,
                            "completion_tokens": 8,
                            "total_tokens": 18,
                        },
                    },
                },
            },
            {
                "id": "batch_req_124",
@@ -319,14 +342,14 @@ class TestOpenAIBatchProcessor:
                    "body": {
                        "error": {
                            "message": "Invalid request",
                            "type": "invalid_request_error"
                            "type": "invalid_request_error",
                        }
                    }
                }
            }
                    },
                },
            },
        ]

        with patch.object(self.processor, 'batch_client') as mock_batch_client:
        with patch.object(self.processor, "batch_client") as mock_batch_client:
            mock_batch_client.poll_batch_status.return_value = mock_batch
            mock_batch_client.retrieve_batch_results.return_value = mock_results

@@ -341,7 +364,7 @@ class TestBaseChatOpenAIBatchMethods:
        """Set up test fixtures."""
        self.llm = ChatOpenAI(model="gpt-3.5-turbo", api_key="test-key")

    @patch('langchain_openai.chat_models.batch.OpenAIBatchProcessor')
    @patch("langchain_openai.chat_models.batch.OpenAIBatchProcessor")
    def test_batch_create_success(self, mock_processor_class):
        """Test successful batch creation."""
        mock_processor = MagicMock()
@@ -369,13 +392,21 @@ class TestBaseChatOpenAIBatchMethods:
            temperature=0.7,
        )

    @patch('langchain_openai.chat_models.batch.OpenAIBatchProcessor')
    @patch("langchain_openai.chat_models.batch.OpenAIBatchProcessor")
    def test_batch_retrieve_success(self, mock_processor_class):
        """Test successful batch result retrieval."""
        mock_processor = MagicMock()
        mock_chat_results = [
            ChatResult(generations=[ChatGeneration(message=AIMessage(content="2+2 equals 4."))]),
            ChatResult(generations=[ChatGeneration(message=AIMessage(content="The capital of France is Paris."))]),
            ChatResult(
                generations=[ChatGeneration(message=AIMessage(content="2+2 equals 4."))]
            ),
            ChatResult(
                generations=[
                    ChatGeneration(
                        message=AIMessage(content="The capital of France is Paris.")
                    )
                ]
            ),
        ]
        mock_processor.retrieve_batch_results.return_value = mock_chat_results
        mock_processor_class.return_value = mock_processor
@@ -384,22 +415,27 @@ class TestBaseChatOpenAIBatchMethods:

        assert len(results) == 2
        assert results[0].generations[0].message.content == "2+2 equals 4."
        assert results[1].generations[0].message.content == "The capital of France is Paris."

        assert (
            results[1].generations[0].message.content
            == "The capital of France is Paris."
        )

        mock_processor.poll_batch_status.assert_called_once_with(
            batch_id="batch_123",
            poll_interval=1.0,
            timeout=60.0,
            batch_id="batch_123", poll_interval=1.0, timeout=60.0
        )
        mock_processor.retrieve_batch_results.assert_called_once_with("batch_123")

    @patch('langchain_openai.chat_models.batch.OpenAIBatchProcessor')
    @patch("langchain_openai.chat_models.batch.OpenAIBatchProcessor")
    def test_batch_method_with_batch_api_true(self, mock_processor_class):
        """Test batch method with use_batch_api=True."""
        mock_processor = MagicMock()
        mock_chat_results = [
            ChatResult(generations=[ChatGeneration(message=AIMessage(content="Response 1"))]),
            ChatResult(generations=[ChatGeneration(message=AIMessage(content="Response 2"))]),
            ChatResult(
                generations=[ChatGeneration(message=AIMessage(content="Response 1"))]
            ),
            ChatResult(
                generations=[ChatGeneration(message=AIMessage(content="Response 2"))]
            ),
        ]
        mock_processor.create_batch.return_value = "batch_123"
        mock_processor.retrieve_batch_results.return_value = mock_chat_results
@@ -429,7 +465,7 @@ class TestBaseChatOpenAIBatchMethods:
        ]

        # Mock the parent class batch method
        with patch.object(ChatOpenAI.__bases__[0], 'batch') as mock_super_batch:
        with patch.object(ChatOpenAI.__bases__[0], "batch") as mock_super_batch:
            mock_super_batch.return_value = [
                AIMessage(content="Response 1"),
                AIMessage(content="Response 2"),
@@ -439,9 +475,7 @@ class TestBaseChatOpenAIBatchMethods:

        assert len(results) == 2
        mock_super_batch.assert_called_once_with(
            inputs=inputs,
            config=None,
            return_exceptions=False,
            inputs=inputs, config=None, return_exceptions=False
        )

    def test_convert_input_to_messages_list(self):
@@ -462,7 +496,7 @@ class TestBaseChatOpenAIBatchMethods:
        assert isinstance(result[0], HumanMessage)
        assert result[0].content == "Hello"

    @patch('langchain_openai.chat_models.batch.OpenAIBatchProcessor')
    @patch("langchain_openai.chat_models.batch.OpenAIBatchProcessor")
    def test_batch_create_with_error_handling(self, mock_processor_class):
        """Test batch creation with error handling."""
        mock_processor = MagicMock()
@@ -474,11 +508,13 @@ class TestBaseChatOpenAIBatchMethods:
        with pytest.raises(BatchError, match="Batch creation failed"):
            self.llm.batch_create(messages_list)

    @patch('langchain_openai.chat_models.batch.OpenAIBatchProcessor')
    @patch("langchain_openai.chat_models.batch.OpenAIBatchProcessor")
    def test_batch_retrieve_with_error_handling(self, mock_processor_class):
        """Test batch retrieval with error handling."""
        mock_processor = MagicMock()
        mock_processor.poll_batch_status.side_effect = BatchError("Batch polling failed")
        mock_processor.poll_batch_status.side_effect = BatchError(
            "Batch polling failed"
        )
        mock_processor_class.return_value = mock_processor

        with pytest.raises(BatchError, match="Batch polling failed"):
@@ -486,12 +522,15 @@ class TestBaseChatOpenAIBatchMethods:

    def test_batch_method_input_conversion(self):
        """Test batch method handles various input formats correctly."""
        with patch.object(self.llm, 'batch_create') as mock_create, \
             patch.object(self.llm, 'batch_retrieve') as mock_retrieve:

        with (
            patch.object(self.llm, "batch_create") as mock_create,
            patch.object(self.llm, "batch_retrieve") as mock_retrieve,
        ):
            mock_create.return_value = "batch_123"
            mock_retrieve.return_value = [
                ChatResult(generations=[ChatGeneration(message=AIMessage(content="Response"))]),
                ChatResult(
                    generations=[ChatGeneration(message=AIMessage(content="Response"))]
                )
            ]

            # Test with string inputs
@@ -501,8 +540,8 @@ class TestBaseChatOpenAIBatchMethods:
            # Verify conversion happened
            mock_create.assert_called_once()
            call_args = mock_create.call_args[1]
            messages_list = call_args['messages_list']

            messages_list = call_args["messages_list"]

            assert len(messages_list) == 1
            assert len(messages_list[0]) == 1
            assert isinstance(messages_list[0][0], HumanMessage)
@@ -532,7 +571,7 @@ class TestBatchIntegrationScenarios:
        """Set up test fixtures."""
        self.llm = ChatOpenAI(model="gpt-3.5-turbo", api_key="test-key")

    @patch('langchain_openai.chat_models.batch.OpenAIBatchProcessor')
    @patch("langchain_openai.chat_models.batch.OpenAIBatchProcessor")
    def test_empty_messages_list(self, mock_processor_class):
        """Test handling of empty messages list."""
        mock_processor = MagicMock()
@@ -543,16 +582,18 @@ class TestBatchIntegrationScenarios:
        results = self.llm.batch([], use_batch_api=True)
        assert results == []

    @patch('langchain_openai.chat_models.batch.OpenAIBatchProcessor')
    @patch("langchain_openai.chat_models.batch.OpenAIBatchProcessor")
    def test_large_batch_processing(self, mock_processor_class):
        """Test processing of large batch."""
        mock_processor = MagicMock()
        mock_processor.create_batch.return_value = "batch_123"


        # Create mock results for large batch
        num_requests = 100
        mock_chat_results = [
            ChatResult(generations=[ChatGeneration(message=AIMessage(content=f"Response {i}"))])
            ChatResult(
                generations=[ChatGeneration(message=AIMessage(content=f"Response {i}"))]
            )
            for i in range(num_requests)
        ]
        mock_processor.retrieve_batch_results.return_value = mock_chat_results
@@ -565,14 +606,18 @@ class TestBatchIntegrationScenarios:
        for i, result in enumerate(results):
            assert result.content == f"Response {i}"

    @patch('langchain_openai.chat_models.batch.OpenAIBatchProcessor')
    @patch("langchain_openai.chat_models.batch.OpenAIBatchProcessor")
    def test_mixed_message_types(self, mock_processor_class):
        """Test batch processing with mixed message types."""
        mock_processor = MagicMock()
        mock_processor.create_batch.return_value = "batch_123"
        mock_processor.retrieve_batch_results.return_value = [
            ChatResult(generations=[ChatGeneration(message=AIMessage(content="Response 1"))]),
            ChatResult(generations=[ChatGeneration(message=AIMessage(content="Response 2"))]),
            ChatResult(
                generations=[ChatGeneration(message=AIMessage(content="Response 1"))]
            ),
            ChatResult(
                generations=[ChatGeneration(message=AIMessage(content="Response 2"))]
            ),
        ]
        mock_processor_class.return_value = mock_processor

@@ -587,12 +632,12 @@ class TestBatchIntegrationScenarios:
        # Verify the conversion happened correctly
        mock_processor.create_batch.assert_called_once()
        call_args = mock_processor.create_batch.call_args[1]
        messages_list = call_args['messages_list']

        messages_list = call_args["messages_list"]

        # First input should be converted to HumanMessage
        assert isinstance(messages_list[0][0], HumanMessage)
        assert messages_list[0][0].content == "String input"


        # Second input should remain as is
        assert isinstance(messages_list[1][0], HumanMessage)
        assert messages_list[1][0].content == "Direct message list"