From c5c43e3ced067a20d489907273259330c079f0cb Mon Sep 17 00:00:00 2001 From: "open-swe[bot]" Date: Mon, 11 Aug 2025 20:33:56 +0000 Subject: [PATCH] Apply patch [skip ci] --- .../openai/langchain_openai/__init__.py | 8 +- .../langchain_openai/chat_models/__init__.py | 1 - .../langchain_openai/chat_models/base.py | 302 +++++++++--------- .../langchain_openai/chat_models/batch.py | 272 ++++++++-------- .../chat_models/test_batch_integration.py | 135 ++++---- .../unit_tests/chat_models/test_batch.py | 189 ++++++----- 6 files changed, 466 insertions(+), 441 deletions(-) diff --git a/libs/partners/openai/langchain_openai/__init__.py b/libs/partners/openai/langchain_openai/__init__.py index 174da8f2813..dd62241195b 100644 --- a/libs/partners/openai/langchain_openai/__init__.py +++ b/libs/partners/openai/langchain_openai/__init__.py @@ -1,4 +1,9 @@ -from langchain_openai.chat_models import AzureChatOpenAI, ChatOpenAI, BatchError, BatchStatus +from langchain_openai.chat_models import ( + AzureChatOpenAI, + BatchError, + BatchStatus, + ChatOpenAI, +) from langchain_openai.embeddings import AzureOpenAIEmbeddings, OpenAIEmbeddings from langchain_openai.llms import AzureOpenAI, OpenAI from langchain_openai.tools import custom_tool @@ -14,4 +19,3 @@ __all__ = [ "BatchError", "BatchStatus", ] - diff --git a/libs/partners/openai/langchain_openai/chat_models/__init__.py b/libs/partners/openai/langchain_openai/chat_models/__init__.py index e1acbb4b87e..85caac42601 100644 --- a/libs/partners/openai/langchain_openai/chat_models/__init__.py +++ b/libs/partners/openai/langchain_openai/chat_models/__init__.py @@ -3,4 +3,3 @@ from langchain_openai.chat_models.base import ChatOpenAI from langchain_openai.chat_models.batch import BatchError, BatchStatus __all__ = ["ChatOpenAI", "AzureChatOpenAI", "BatchError", "BatchStatus"] - diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 30787dd3dbf..66cf27490c8 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -1503,142 +1503,142 @@ class BaseChatOpenAI(BaseChatModel): ) -> int: """Calculate num tokens for ``gpt-3.5-turbo`` and ``gpt-4`` with ``tiktoken`` package. - **Requirements**: You must have the ``pillow`` installed if you want to count - image tokens if you are specifying the image as a base64 string, and you must - have both ``pillow`` and ``httpx`` installed if you are specifying the image - as a URL. If these aren't installed image inputs will be ignored in token - counting. + **Requirements**: You must have the ``pillow`` installed if you want to count + image tokens if you are specifying the image as a base64 string, and you must + have both ``pillow`` and ``httpx`` installed if you are specifying the image + as a URL. If these aren't installed image inputs will be ignored in token + counting. - `OpenAI reference `__ + `OpenAI reference `__ - Args: - messages: The message inputs to tokenize. - tools: If provided, sequence of dict, BaseModel, function, or BaseTools - to be converted to tool schemas. - .. dropdown:: Batch API for cost savings + Args: + messages: The message inputs to tokenize. + tools: If provided, sequence of dict, BaseModel, function, or BaseTools + to be converted to tool schemas. + .. dropdown:: Batch API for cost savings - .. versionadded:: 0.3.7 + .. 
versionadded:: 0.3.7 - OpenAI's Batch API provides **50% cost savings** for non-real-time workloads by - processing requests asynchronously. This is ideal for tasks like data processing, - content generation, or evaluation that don't require immediate responses. + OpenAI's Batch API provides **50% cost savings** for non-real-time workloads by + processing requests asynchronously. This is ideal for tasks like data processing, + content generation, or evaluation that don't require immediate responses. - **Cost vs Latency Tradeoff:** + **Cost vs Latency Tradeoff:** - - **Standard API**: Immediate results, full pricing - - **Batch API**: 50% cost savings, asynchronous processing (results available within 24 hours) + - **Standard API**: Immediate results, full pricing + - **Batch API**: 50% cost savings, asynchronous processing (results available within 24 hours) - **Method 1: Direct batch management** + **Method 1: Direct batch management** - Use ``batch_create()`` and ``batch_retrieve()`` for full control over batch lifecycle: + Use ``batch_create()`` and ``batch_retrieve()`` for full control over batch lifecycle: - .. code-block:: python + .. code-block:: python - from langchain_openai import ChatOpenAI - from langchain_core.messages import HumanMessage + from langchain_openai import ChatOpenAI + from langchain_core.messages import HumanMessage - llm = ChatOpenAI(model="gpt-3.5-turbo") + llm = ChatOpenAI(model="gpt-3.5-turbo") - # Prepare multiple message sequences for batch processing - messages_list = [ - [HumanMessage(content="Translate 'hello' to French")], - [HumanMessage(content="Translate 'goodbye' to Spanish")], - [HumanMessage(content="What is the capital of Italy?")], - ] + # Prepare multiple message sequences for batch processing + messages_list = [ + [HumanMessage(content="Translate 'hello' to French")], + [HumanMessage(content="Translate 'goodbye' to Spanish")], + [HumanMessage(content="What is the capital of Italy?")], + ] - # Create batch job (returns immediately with batch ID) - batch_id = llm.batch_create( - messages_list=messages_list, - description="Translation and geography batch", - metadata={"project": "multilingual_qa", "user": "analyst_1"}, - ) - print(f"Batch created: {batch_id}") + # Create batch job (returns immediately with batch ID) + batch_id = llm.batch_create( + messages_list=messages_list, + description="Translation and geography batch", + metadata={"project": "multilingual_qa", "user": "analyst_1"}, + ) + print(f"Batch created: {batch_id}") - # Later, retrieve results (polls until completion) - results = llm.batch_retrieve( - batch_id=batch_id, - poll_interval=60.0, # Check every minute - timeout=3600.0, # 1 hour timeout - ) + # Later, retrieve results (polls until completion) + results = llm.batch_retrieve( + batch_id=batch_id, + poll_interval=60.0, # Check every minute + timeout=3600.0, # 1 hour timeout + ) - # Process results - for i, result in enumerate(results): - response = result.generations[0].message.content - print(f"Response {i+1}: {response}") + # Process results + for i, result in enumerate(results): + response = result.generations[0].message.content + print(f"Response {i + 1}: {response}") - **Method 2: Enhanced batch() method** + **Method 2: Enhanced batch() method** - Use the familiar ``batch()`` method with ``use_batch_api=True`` for seamless integration: + Use the familiar ``batch()`` method with ``use_batch_api=True`` for seamless integration: - .. code-block:: python + .. 
code-block:: python - # Standard batch processing (immediate, full cost) - inputs = [ - [HumanMessage(content="What is 2+2?")], - [HumanMessage(content="What is 3+3?")], - ] - standard_results = llm.batch(inputs) # Default: use_batch_api=False + # Standard batch processing (immediate, full cost) + inputs = [ + [HumanMessage(content="What is 2+2?")], + [HumanMessage(content="What is 3+3?")], + ] + standard_results = llm.batch(inputs) # Default: use_batch_api=False - # Batch API processing (50% cost savings, polling) - batch_results = llm.batch( - inputs, - use_batch_api=True, # Enable cost savings - poll_interval=30.0, # Poll every 30 seconds - timeout=1800.0, # 30 minute timeout - ) + # Batch API processing (50% cost savings, polling) + batch_results = llm.batch( + inputs, + use_batch_api=True, # Enable cost savings + poll_interval=30.0, # Poll every 30 seconds + timeout=1800.0, # 30 minute timeout + ) - **Batch creation with custom parameters:** + **Batch creation with custom parameters:** - .. code-block:: python + .. code-block:: python - # Create batch with specific model parameters - batch_id = llm.batch_create( - messages_list=messages_list, - description="Creative writing batch", - metadata={"task_type": "content_generation"}, - temperature=0.8, # Higher creativity - max_tokens=200, # Longer responses - top_p=0.9, # Nucleus sampling - ) + # Create batch with specific model parameters + batch_id = llm.batch_create( + messages_list=messages_list, + description="Creative writing batch", + metadata={"task_type": "content_generation"}, + temperature=0.8, # Higher creativity + max_tokens=200, # Longer responses + top_p=0.9, # Nucleus sampling + ) - **Error handling and monitoring:** + **Error handling and monitoring:** - .. code-block:: python + .. code-block:: python - from langchain_openai.chat_models.batch import BatchError + from langchain_openai.chat_models.batch import BatchError - try: - batch_id = llm.batch_create(messages_list) - results = llm.batch_retrieve(batch_id, timeout=600.0) - except BatchError as e: - print(f"Batch processing failed: {e}") - # Handle batch failure (retry, fallback to standard API, etc.) + try: + batch_id = llm.batch_create(messages_list) + results = llm.batch_retrieve(batch_id, timeout=600.0) + except BatchError as e: + print(f"Batch processing failed: {e}") + # Handle batch failure (retry, fallback to standard API, etc.) 
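                    # Illustrative fallback sketch (not part of the original example;
                    # reuses the ``llm`` and ``messages_list`` names from above):
                    # retry via the standard API so the workload still completes,
                    # trading the Batch API's 50% savings for immediate processing.
                    fallback_results = llm.batch(messages_list)  # use_batch_api defaults to False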
- **Best practices:** + **Best practices:** - - Use batch API for **non-urgent tasks** where 50% cost savings justify longer wait times - - Set appropriate **timeouts** based on batch size (larger batches take longer) - - Include **descriptive metadata** for tracking and debugging batch jobs - - Consider **fallback strategies** for time-sensitive applications - - Monitor batch status for **long-running jobs** to detect failures early + - Use batch API for **non-urgent tasks** where 50% cost savings justify longer wait times + - Set appropriate **timeouts** based on batch size (larger batches take longer) + - Include **descriptive metadata** for tracking and debugging batch jobs + - Consider **fallback strategies** for time-sensitive applications + - Monitor batch status for **long-running jobs** to detect failures early - **When to use Batch API:** + **When to use Batch API:** - ✅ **Good for:** - - Data processing and analysis - - Content generation at scale - - Model evaluation and testing - - Batch translation or summarization - - Non-interactive applications + ✅ **Good for:** + - Data processing and analysis + - Content generation at scale + - Model evaluation and testing + - Batch translation or summarization + - Non-interactive applications + + ❌ **Not suitable for:** + - Real-time chat applications + - Interactive user interfaces + - Time-critical decision making + - Applications requiring immediate responses - ❌ **Not suitable for:** - - Real-time chat applications - - Interactive user interfaces - - Time-critical decision making - - Applications requiring immediate responses - """ # noqa: E501 # TODO: Count bound tools as part of input. if tools is not None: @@ -2147,6 +2147,7 @@ class BaseChatOpenAI(BaseChatModel): else: filtered[k] = v return filtered + def batch_create( self, messages_list: List[List[BaseMessage]], @@ -2159,10 +2160,10 @@ class BaseChatOpenAI(BaseChatModel): ) -> str: """ Create a batch job using OpenAI's Batch API for asynchronous processing. - + This method provides 50% cost savings compared to the standard API in exchange for asynchronous processing with polling for results. - + Args: messages_list: List of message sequences to process in batch. description: Optional description for the batch job. @@ -2170,34 +2171,34 @@ class BaseChatOpenAI(BaseChatModel): poll_interval: Default time in seconds between status checks when polling. timeout: Default maximum time in seconds to wait for completion. **kwargs: Additional parameters to pass to chat completions. - + Returns: The batch ID for tracking the asynchronous job. - + Raises: BatchError: If batch creation fails. - + Example: .. 
code-block:: python - + from langchain_openai import ChatOpenAI from langchain_core.messages import HumanMessage - + llm = ChatOpenAI() messages_list = [ [HumanMessage(content="What is 2+2?")], [HumanMessage(content="What is the capital of France?")], ] - + # Create batch job (50% cost savings) batch_id = llm.batch_create(messages_list) - + # Later, retrieve results results = llm.batch_retrieve(batch_id) """ # Import here to avoid circular imports from langchain_openai.chat_models.batch import OpenAIBatchProcessor - + # Create batch processor with current model settings processor = OpenAIBatchProcessor( client=self.root_client, @@ -2205,19 +2206,19 @@ class BaseChatOpenAI(BaseChatModel): poll_interval=poll_interval, timeout=timeout, ) - + # Filter and prepare kwargs for batch processing batch_kwargs = self._get_invocation_params(**kwargs) # Remove model from kwargs since it's handled by the processor batch_kwargs.pop("model", None) - + return processor.create_batch( messages_list=messages_list, description=description, metadata=metadata, **batch_kwargs, ) - + def batch_retrieve( self, batch_id: str, @@ -2227,36 +2228,36 @@ class BaseChatOpenAI(BaseChatModel): ) -> List[ChatResult]: """ Retrieve results from a batch job, polling until completion if necessary. - + This method will poll the batch status until completion and return the results converted to LangChain ChatResult format. - + Args: batch_id: The batch ID returned from batch_create(). poll_interval: Time in seconds between status checks. Uses default if None. timeout: Maximum time in seconds to wait. Uses default if None. - + Returns: List of ChatResult objects corresponding to the original message sequences. - + Raises: BatchError: If batch retrieval fails, times out, or batch job failed. - + Example: .. code-block:: python - + # After creating a batch job batch_id = llm.batch_create(messages_list) - + # Retrieve results (will poll until completion) results = llm.batch_retrieve(batch_id) - + for result in results: print(result.generations[0].message.content) """ # Import here to avoid circular imports from langchain_openai.chat_models.batch import OpenAIBatchProcessor - + # Create batch processor with current model settings processor = OpenAIBatchProcessor( client=self.root_client, @@ -2264,15 +2265,14 @@ class BaseChatOpenAI(BaseChatModel): poll_interval=poll_interval or 10.0, timeout=timeout, ) - + # Poll for completion and retrieve results processor.poll_batch_status( - batch_id=batch_id, - poll_interval=poll_interval, - timeout=timeout, + batch_id=batch_id, poll_interval=poll_interval, timeout=timeout ) - + return processor.retrieve_batch_results(batch_id) + @override def batch( self, @@ -2285,11 +2285,11 @@ class BaseChatOpenAI(BaseChatModel): ) -> List[BaseMessage]: """ Batch process multiple inputs using either standard API or OpenAI Batch API. - + This method provides two processing modes: 1. Standard mode (use_batch_api=False): Uses parallel invoke for immediate results 2. Batch API mode (use_batch_api=True): Uses OpenAI's Batch API for 50% cost savings - + Args: inputs: List of inputs to process in batch. config: Configuration for the batch processing. @@ -2297,33 +2297,33 @@ class BaseChatOpenAI(BaseChatModel): use_batch_api: If True, use OpenAI's Batch API for cost savings with polling. If False (default), use standard parallel processing for immediate results. **kwargs: Additional parameters to pass to the underlying model. - + Returns: List of BaseMessage objects corresponding to the inputs. 
- + Raises: BatchError: If batch processing fails (when use_batch_api=True). - + Note: **Cost vs Latency Tradeoff:** - use_batch_api=False: Immediate results, standard API pricing - use_batch_api=True: 50% cost savings, asynchronous processing with polling - + Example: .. code-block:: python - + from langchain_openai import ChatOpenAI from langchain_core.messages import HumanMessage - + llm = ChatOpenAI() inputs = [ [HumanMessage(content="What is 2+2?")], [HumanMessage(content="What is the capital of France?")], ] - + # Standard processing (immediate results) results = llm.batch(inputs) - + # Batch API processing (50% cost savings, polling required) results = llm.batch(inputs, use_batch_api=True) """ @@ -2338,11 +2338,11 @@ class BaseChatOpenAI(BaseChatModel): # Convert single input to list of messages messages = self._convert_input_to_messages(input_item) messages_list.append(messages) - + # Create batch job and poll for results batch_id = self.batch_create(messages_list, **kwargs) chat_results = self.batch_retrieve(batch_id) - + # Convert ChatResult objects to BaseMessage objects return [result.generations[0].message for result in chat_results] else: @@ -2354,7 +2354,9 @@ class BaseChatOpenAI(BaseChatModel): **kwargs, ) - def _convert_input_to_messages(self, input_item: LanguageModelInput) -> List[BaseMessage]: + def _convert_input_to_messages( + self, input_item: LanguageModelInput + ) -> List[BaseMessage]: """Convert various input formats to a list of BaseMessage objects.""" if isinstance(input_item, list): # Already a list of messages @@ -2365,27 +2367,17 @@ class BaseChatOpenAI(BaseChatModel): elif isinstance(input_item, str): # String input - convert to HumanMessage from langchain_core.messages import HumanMessage + return [HumanMessage(content=input_item)] - elif hasattr(input_item, 'to_messages'): + elif hasattr(input_item, "to_messages"): # PromptValue or similar return input_item.to_messages() else: # Try to convert to string and then to HumanMessage from langchain_core.messages import HumanMessage + return [HumanMessage(content=str(input_item))] - - - - - - - - - - - - def _get_generation_chunk_from_completion( self, completion: openai.BaseModel ) -> ChatGenerationChunk: diff --git a/libs/partners/openai/langchain_openai/chat_models/batch.py b/libs/partners/openai/langchain_openai/chat_models/batch.py index ba369a52faa..bdb19095769 100644 --- a/libs/partners/openai/langchain_openai/chat_models/batch.py +++ b/libs/partners/openai/langchain_openai/chat_models/batch.py @@ -5,19 +5,22 @@ from __future__ import annotations import json import time from enum import Enum -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional from uuid import uuid4 import openai from langchain_core.messages import BaseMessage from langchain_core.outputs import ChatGeneration, ChatResult -from langchain_openai.chat_models._compat import convert_dict_to_message, convert_message_to_dict +from langchain_openai.chat_models._compat import ( + convert_dict_to_message, + convert_message_to_dict, +) class BatchStatus(str, Enum): """OpenAI Batch API status values.""" - + VALIDATING = "validating" FAILED = "failed" IN_PROGRESS = "in_progress" @@ -30,8 +33,10 @@ class BatchStatus(str, Enum): class BatchError(Exception): """Exception raised when batch processing fails.""" - - def __init__(self, message: str, batch_id: Optional[str] = None, status: Optional[str] = None): + + def __init__( + self, message: str, batch_id: Optional[str] = None, status: Optional[str] = None + 
): super().__init__(message) self.batch_id = batch_id self.status = status @@ -39,22 +44,22 @@ class BatchError(Exception): class OpenAIBatchClient: """ - OpenAI Batch API client wrapper that handles batch creation, status polling, + OpenAI Batch API client wrapper that handles batch creation, status polling, and result retrieval. - + This class provides a high-level interface to OpenAI's Batch API, which offers 50% cost savings compared to the standard API in exchange for asynchronous processing. """ - + def __init__(self, client: openai.OpenAI): """ Initialize the batch client. - + Args: client: OpenAI client instance to use for API calls. """ self.client = client - + def create_batch( self, requests: List[Dict[str, Any]], @@ -63,53 +68,52 @@ class OpenAIBatchClient: ) -> str: """ Create a new batch job with the OpenAI Batch API. - + Args: requests: List of request objects in OpenAI batch format. description: Optional description for the batch job. metadata: Optional metadata to attach to the batch job. - + Returns: The batch ID for tracking the job. - + Raises: BatchError: If batch creation fails. """ try: # Create JSONL content for the batch jsonl_content = "\n".join(json.dumps(req) for req in requests) - + # Upload the batch file file_response = self.client.files.create( - file=jsonl_content.encode('utf-8'), - purpose="batch" + file=jsonl_content.encode("utf-8"), purpose="batch" ) - + # Create the batch job batch_response = self.client.batches.create( input_file_id=file_response.id, endpoint="/v1/chat/completions", completion_window="24h", - metadata=metadata or {} + metadata=metadata or {}, ) - + return batch_response.id - + except openai.OpenAIError as e: raise BatchError(f"Failed to create batch: {e}") from e except Exception as e: raise BatchError(f"Unexpected error creating batch: {e}") from e - + def retrieve_batch(self, batch_id: str) -> Dict[str, Any]: """ Retrieve batch information by ID. - + Args: batch_id: The batch ID to retrieve. - + Returns: Dictionary containing batch information including status. - + Raises: BatchError: If batch retrieval fails. 
""" @@ -119,20 +123,24 @@ class OpenAIBatchClient: "id": batch.id, "status": batch.status, "created_at": batch.created_at, - "completed_at": getattr(batch, 'completed_at', None), - "failed_at": getattr(batch, 'failed_at', None), - "expired_at": getattr(batch, 'expired_at', None), - "request_counts": getattr(batch, 'request_counts', {}), - "metadata": getattr(batch, 'metadata', {}), - "errors": getattr(batch, 'errors', None), - "output_file_id": getattr(batch, 'output_file_id', None), - "error_file_id": getattr(batch, 'error_file_id', None), + "completed_at": getattr(batch, "completed_at", None), + "failed_at": getattr(batch, "failed_at", None), + "expired_at": getattr(batch, "expired_at", None), + "request_counts": getattr(batch, "request_counts", {}), + "metadata": getattr(batch, "metadata", {}), + "errors": getattr(batch, "errors", None), + "output_file_id": getattr(batch, "output_file_id", None), + "error_file_id": getattr(batch, "error_file_id", None), } except openai.OpenAIError as e: - raise BatchError(f"Failed to retrieve batch {batch_id}: {e}", batch_id=batch_id) from e + raise BatchError( + f"Failed to retrieve batch {batch_id}: {e}", batch_id=batch_id + ) from e except Exception as e: - raise BatchError(f"Unexpected error retrieving batch {batch_id}: {e}", batch_id=batch_id) from e - + raise BatchError( + f"Unexpected error retrieving batch {batch_id}: {e}", batch_id=batch_id + ) from e + def poll_batch_status( self, batch_id: str, @@ -141,94 +149,106 @@ class OpenAIBatchClient: ) -> Dict[str, Any]: """ Poll batch status until completion or failure. - + Args: batch_id: The batch ID to poll. poll_interval: Time in seconds between status checks. timeout: Maximum time in seconds to wait. None for no timeout. - + Returns: Final batch information when completed. - + Raises: BatchError: If batch fails or times out. """ start_time = time.time() - + while True: batch_info = self.retrieve_batch(batch_id) status = batch_info["status"] - + if status == BatchStatus.COMPLETED: return batch_info - elif status in [BatchStatus.FAILED, BatchStatus.EXPIRED, BatchStatus.CANCELLED]: + elif status in [ + BatchStatus.FAILED, + BatchStatus.EXPIRED, + BatchStatus.CANCELLED, + ]: error_msg = f"Batch {batch_id} failed with status: {status}" if batch_info.get("errors"): error_msg += f". Errors: {batch_info['errors']}" raise BatchError(error_msg, batch_id=batch_id, status=status) - + # Check timeout if timeout and (time.time() - start_time) > timeout: raise BatchError( f"Batch {batch_id} timed out after {timeout} seconds. Current status: {status}", batch_id=batch_id, - status=status + status=status, ) - + time.sleep(poll_interval) - + def retrieve_batch_results(self, batch_id: str) -> List[Dict[str, Any]]: """ Retrieve results from a completed batch. - + Args: batch_id: The batch ID to retrieve results for. - + Returns: List of result objects from the batch. - + Raises: BatchError: If batch is not completed or result retrieval fails. """ try: batch_info = self.retrieve_batch(batch_id) - + if batch_info["status"] != BatchStatus.COMPLETED: raise BatchError( f"Batch {batch_id} is not completed. 
Current status: {batch_info['status']}", batch_id=batch_id, - status=batch_info["status"] + status=batch_info["status"], ) - + output_file_id = batch_info.get("output_file_id") if not output_file_id: - raise BatchError(f"No output file found for batch {batch_id}", batch_id=batch_id) - + raise BatchError( + f"No output file found for batch {batch_id}", batch_id=batch_id + ) + # Download and parse the results file file_content = self.client.files.content(output_file_id) results = [] - - for line in file_content.text.strip().split('\n'): + + for line in file_content.text.strip().split("\n"): if line.strip(): results.append(json.loads(line)) - + return results - + except openai.OpenAIError as e: - raise BatchError(f"Failed to retrieve results for batch {batch_id}: {e}", batch_id=batch_id) from e + raise BatchError( + f"Failed to retrieve results for batch {batch_id}: {e}", + batch_id=batch_id, + ) from e except Exception as e: - raise BatchError(f"Unexpected error retrieving results for batch {batch_id}: {e}", batch_id=batch_id) from e - + raise BatchError( + f"Unexpected error retrieving results for batch {batch_id}: {e}", + batch_id=batch_id, + ) from e + def cancel_batch(self, batch_id: str) -> Dict[str, Any]: """ Cancel a batch job. - + Args: batch_id: The batch ID to cancel. - + Returns: Updated batch information after cancellation. - + Raises: BatchError: If batch cancellation fails. """ @@ -236,22 +256,26 @@ class OpenAIBatchClient: batch = self.client.batches.cancel(batch_id) return self.retrieve_batch(batch_id) except openai.OpenAIError as e: - raise BatchError(f"Failed to cancel batch {batch_id}: {e}", batch_id=batch_id) from e + raise BatchError( + f"Failed to cancel batch {batch_id}: {e}", batch_id=batch_id + ) from e except Exception as e: - raise BatchError(f"Unexpected error cancelling batch {batch_id}: {e}", batch_id=batch_id) from e + raise BatchError( + f"Unexpected error cancelling batch {batch_id}: {e}", batch_id=batch_id + ) from e class OpenAIBatchProcessor: """ High-level processor for managing OpenAI Batch API lifecycle with LangChain integration. - + This class handles the complete batch processing workflow: 1. Converts LangChain messages to OpenAI batch format 2. Creates batch jobs using the OpenAI Batch API 3. Polls for completion with configurable intervals 4. Converts results back to LangChain format """ - + def __init__( self, client: openai.OpenAI, @@ -261,7 +285,7 @@ class OpenAIBatchProcessor: ): """ Initialize the batch processor. - + Args: client: OpenAI client instance to use for API calls. model: The model to use for batch requests. @@ -272,7 +296,7 @@ class OpenAIBatchProcessor: self.model = model self.poll_interval = poll_interval self.timeout = timeout - + def create_batch( self, messages_list: List[List[BaseMessage]], @@ -282,16 +306,16 @@ class OpenAIBatchProcessor: ) -> str: """ Create a batch job from a list of LangChain message sequences. - + Args: messages_list: List of message sequences to process in batch. description: Optional description for the batch job. metadata: Optional metadata to attach to the batch job. **kwargs: Additional parameters to pass to chat completions. - + Returns: The batch ID for tracking the job. - + Raises: BatchError: If batch creation fails. 
""" @@ -300,19 +324,14 @@ class OpenAIBatchProcessor: for i, messages in enumerate(messages_list): custom_id = f"request_{i}_{uuid4().hex[:8]}" request = create_batch_request( - messages=messages, - model=self.model, - custom_id=custom_id, - **kwargs + messages=messages, model=self.model, custom_id=custom_id, **kwargs ) requests.append(request) - + return self.batch_client.create_batch( - requests=requests, - description=description, - metadata=metadata, + requests=requests, description=description, metadata=metadata ) - + def poll_batch_status( self, batch_id: str, @@ -321,15 +340,15 @@ class OpenAIBatchProcessor: ) -> Dict[str, Any]: """ Poll batch status until completion or failure. - + Args: batch_id: The batch ID to poll. poll_interval: Time in seconds between status checks. Uses default if None. timeout: Maximum time in seconds to wait. Uses default if None. - + Returns: Final batch information when completed. - + Raises: BatchError: If batch fails or times out. """ @@ -338,26 +357,26 @@ class OpenAIBatchProcessor: poll_interval=poll_interval or self.poll_interval, timeout=timeout or self.timeout, ) - + def retrieve_batch_results(self, batch_id: str) -> List[ChatResult]: """ Retrieve and convert batch results to LangChain format. - + Args: batch_id: The batch ID to retrieve results for. - + Returns: List of ChatResult objects corresponding to the original message sequences. - + Raises: BatchError: If batch is not completed or result retrieval fails. """ # Get raw results from OpenAI raw_results = self.batch_client.retrieve_batch_results(batch_id) - + # Sort results by custom_id to maintain order raw_results.sort(key=lambda x: x.get("custom_id", "")) - + # Convert to LangChain ChatResult format chat_results = [] for result in raw_results: @@ -365,39 +384,42 @@ class OpenAIBatchProcessor: # Handle individual request errors error_msg = f"Request failed: {result['error']}" raise BatchError(error_msg, batch_id=batch_id) - + response = result.get("response", {}) if not response: - raise BatchError(f"No response found in result: {result}", batch_id=batch_id) - + raise BatchError( + f"No response found in result: {result}", batch_id=batch_id + ) + body = response.get("body", {}) choices = body.get("choices", []) - + if not choices: - raise BatchError(f"No choices found in response: {body}", batch_id=batch_id) - + raise BatchError( + f"No choices found in response: {body}", batch_id=batch_id + ) + # Convert OpenAI response to LangChain format generations = [] for choice in choices: message_dict = choice.get("message", {}) if not message_dict: continue - + # Convert OpenAI message dict to LangChain message message = convert_dict_to_message(message_dict) - + # Create ChatGeneration with metadata generation_info = { "finish_reason": choice.get("finish_reason"), "logprobs": choice.get("logprobs"), } - + generation = ChatGeneration( - message=message, - generation_info=generation_info, + message=message, generation_info=generation_info ) generations.append(generation) - + # Create ChatResult with usage information usage = body.get("usage", {}) llm_output = { @@ -405,15 +427,12 @@ class OpenAIBatchProcessor: "model_name": body.get("model"), "system_fingerprint": body.get("system_fingerprint"), } - - chat_result = ChatResult( - generations=generations, - llm_output=llm_output, - ) + + chat_result = ChatResult(generations=generations, llm_output=llm_output) chat_results.append(chat_result) - + return chat_results - + def process_batch( self, messages_list: List[List[BaseMessage]], @@ -425,7 +444,7 @@ 
class OpenAIBatchProcessor: ) -> List[ChatResult]: """ Complete batch processing workflow: create, poll, and retrieve results. - + Args: messages_list: List of message sequences to process in batch. description: Optional description for the batch job. @@ -433,10 +452,10 @@ class OpenAIBatchProcessor: poll_interval: Time in seconds between status checks. Uses default if None. timeout: Maximum time in seconds to wait. Uses default if None. **kwargs: Additional parameters to pass to chat completions. - + Returns: List of ChatResult objects corresponding to the original message sequences. - + Raises: BatchError: If any step of the batch processing fails. """ @@ -445,50 +464,39 @@ class OpenAIBatchProcessor: messages_list=messages_list, description=description, metadata=metadata, - **kwargs + **kwargs, ) - + # Poll until completion self.poll_batch_status( - batch_id=batch_id, - poll_interval=poll_interval, - timeout=timeout, + batch_id=batch_id, poll_interval=poll_interval, timeout=timeout ) - + # Retrieve and return results return self.retrieve_batch_results(batch_id) def create_batch_request( - messages: List[BaseMessage], - model: str, - custom_id: str, - **kwargs: Any, + messages: List[BaseMessage], model: str, custom_id: str, **kwargs: Any ) -> Dict[str, Any]: """ Create a batch request object from LangChain messages. - + Args: messages: List of LangChain messages to convert. model: The model to use for the request. custom_id: Unique identifier for this request within the batch. **kwargs: Additional parameters to pass to the chat completion. - + Returns: Dictionary in OpenAI batch request format. """ # Convert LangChain messages to OpenAI format openai_messages = [convert_message_to_dict(msg) for msg in messages] - + return { "custom_id": custom_id, "method": "POST", "url": "/v1/chat/completions", - "body": { - "model": model, - "messages": openai_messages, - **kwargs - } + "body": {"model": model, "messages": openai_messages, **kwargs}, } - - diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_batch_integration.py b/libs/partners/openai/tests/integration_tests/chat_models/test_batch_integration.py index cf1e150fc56..8428c711338 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_batch_integration.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_batch_integration.py @@ -6,20 +6,18 @@ They are designed to test the complete end-to-end batch processing workflow. import os import time -from typing import List import pytest -from langchain_core.messages import AIMessage, BaseMessage, HumanMessage +from langchain_core.messages import AIMessage, HumanMessage from langchain_core.outputs import ChatResult from langchain_openai import ChatOpenAI from langchain_openai.chat_models.batch import BatchError - # Skip all tests if no API key is available pytestmark = pytest.mark.skipif( not os.environ.get("OPENAI_API_KEY"), - reason="OPENAI_API_KEY not set, skipping integration tests" + reason="OPENAI_API_KEY not set, skipping integration tests", ) @@ -31,7 +29,7 @@ class TestBatchAPIIntegration: self.llm = ChatOpenAI( model="gpt-3.5-turbo", temperature=0.1, # Low temperature for consistent results - max_tokens=50, # Keep responses short for faster processing + max_tokens=50, # Keep responses short for faster processing ) @pytest.mark.scheduled @@ -40,7 +38,11 @@ class TestBatchAPIIntegration: # Create a small batch of simple questions messages_list = [ [HumanMessage(content="What is 2+2? 
Answer with just the number.")], - [HumanMessage(content="What is the capital of France? Answer with just the city name.")], + [ + HumanMessage( + content="What is the capital of France? Answer with just the city name." + ) + ], ] # Create batch job @@ -58,19 +60,21 @@ class TestBatchAPIIntegration: results = self.llm.batch_retrieve( batch_id=batch_id, poll_interval=30.0, # Poll every 30 seconds - timeout=1800.0, # 30 minute timeout + timeout=1800.0, # 30 minute timeout ) # Verify results assert len(results) == 2 assert all(isinstance(result, ChatResult) for result in results) assert all(len(result.generations) == 1 for result in results) - assert all(isinstance(result.generations[0].message, AIMessage) for result in results) + assert all( + isinstance(result.generations[0].message, AIMessage) for result in results + ) # Check that we got reasonable responses response1 = results[0].generations[0].message.content.strip() response2 = results[1].generations[0].message.content.strip() - + # Basic sanity checks (responses should contain expected content) assert "4" in response1 or "four" in response1.lower() assert "paris" in response2.lower() @@ -85,10 +89,7 @@ class TestBatchAPIIntegration: # Use batch API mode results = self.llm.batch( - inputs, - use_batch_api=True, - poll_interval=30.0, - timeout=1800.0, + inputs, use_batch_api=True, poll_interval=30.0, timeout=1800.0 ) # Verify results @@ -99,37 +100,32 @@ class TestBatchAPIIntegration: # Basic sanity checks response1 = results[0].content.strip().lower() response2 = results[1].content.strip().lower() - + assert any(char in response1 for char in ["1", "2", "3"]) assert "blue" in response2 @pytest.mark.scheduled def test_batch_method_comparison(self): """Test that batch API and standard batch produce similar results.""" - inputs = [ - [HumanMessage(content="What is 1+1? Answer with just the number.")], - ] + inputs = [[HumanMessage(content="What is 1+1? Answer with just the number.")]] # Test standard batch processing standard_results = self.llm.batch(inputs, use_batch_api=False) - + # Test batch API processing batch_api_results = self.llm.batch( - inputs, - use_batch_api=True, - poll_interval=30.0, - timeout=1800.0, + inputs, use_batch_api=True, poll_interval=30.0, timeout=1800.0 ) # Both should return similar structure assert len(standard_results) == len(batch_api_results) == 1 assert isinstance(standard_results[0], AIMessage) assert isinstance(batch_api_results[0], AIMessage) - + # Both should contain reasonable answers standard_content = standard_results[0].content.strip() batch_content = batch_api_results[0].content.strip() - + assert "2" in standard_content or "two" in standard_content.lower() assert "2" in batch_content or "two" in batch_content.lower() @@ -137,7 +133,7 @@ class TestBatchAPIIntegration: def test_batch_with_different_parameters(self): """Test batch processing with different model parameters.""" messages_list = [ - [HumanMessage(content="Write a haiku about coding. Keep it short.")], + [HumanMessage(content="Write a haiku about coding. 
Keep it short.")] ] # Create batch with specific parameters @@ -146,18 +142,16 @@ class TestBatchAPIIntegration: description="Integration test - parameters", metadata={"test_type": "parameters"}, temperature=0.8, # Higher temperature for creativity - max_tokens=100, # More tokens for haiku + max_tokens=100, # More tokens for haiku ) results = self.llm.batch_retrieve( - batch_id=batch_id, - poll_interval=30.0, - timeout=1800.0, + batch_id=batch_id, poll_interval=30.0, timeout=1800.0 ) assert len(results) == 1 result_content = results[0].generations[0].message.content - + # Should have some content (haiku) assert len(result_content.strip()) > 10 # Haikus typically have line breaks @@ -170,25 +164,24 @@ class TestBatchAPIIntegration: messages_list = [ [ - SystemMessage(content="You are a helpful math tutor. Answer concisely."), + SystemMessage( + content="You are a helpful math tutor. Answer concisely." + ), HumanMessage(content="What is 5 * 6?"), - ], + ] ] batch_id = self.llm.batch_create( - messages_list=messages_list, - description="Integration test - system message", + messages_list=messages_list, description="Integration test - system message" ) results = self.llm.batch_retrieve( - batch_id=batch_id, - poll_interval=30.0, - timeout=1800.0, + batch_id=batch_id, poll_interval=30.0, timeout=1800.0 ) assert len(results) == 1 result_content = results[0].generations[0].message.content.strip() - + # Should contain the answer assert "30" in result_content or "thirty" in result_content.lower() @@ -196,14 +189,9 @@ class TestBatchAPIIntegration: def test_batch_error_handling_invalid_model(self): """Test error handling with invalid model parameters.""" # Create a ChatOpenAI instance with an invalid model - invalid_llm = ChatOpenAI( - model="invalid-model-name-12345", - temperature=0.1, - ) + invalid_llm = ChatOpenAI(model="invalid-model-name-12345", temperature=0.1) - messages_list = [ - [HumanMessage(content="Hello")], - ] + messages_list = [[HumanMessage(content="Hello")]] # This should fail during batch creation or processing with pytest.raises(BatchError): @@ -216,23 +204,24 @@ class TestBatchAPIIntegration: # Test with string inputs (should be converted to HumanMessage) inputs = [ "What is the largest planet? Answer with just the planet name.", - [HumanMessage(content="What is the smallest planet? Answer with just the planet name.")], + [ + HumanMessage( + content="What is the smallest planet? Answer with just the planet name." 
+ ) + ], ] results = self.llm.batch( - inputs, - use_batch_api=True, - poll_interval=30.0, - timeout=1800.0, + inputs, use_batch_api=True, poll_interval=30.0, timeout=1800.0 ) assert len(results) == 2 assert all(isinstance(result, AIMessage) for result in results) - + # Check for reasonable responses response1 = results[0].content.strip().lower() response2 = results[1].content.strip().lower() - + assert "jupiter" in response1 assert "mercury" in response2 @@ -248,12 +237,10 @@ class TestBatchAPIIntegration: results = self.llm.batch_retrieve(batch_id, timeout=300.0) assert results == [] - @pytest.mark.scheduled + @pytest.mark.scheduled def test_batch_metadata_preservation(self): """Test that batch metadata is properly handled.""" - messages_list = [ - [HumanMessage(content="Say 'test successful'")], - ] + messages_list = [[HumanMessage(content="Say 'test successful'")]] metadata = { "test_name": "metadata_test", @@ -270,7 +257,7 @@ class TestBatchAPIIntegration: # Retrieve results results = self.llm.batch_retrieve(batch_id, timeout=1800.0) - + assert len(results) == 1 result_content = results[0].generations[0].message.content.strip().lower() assert "test successful" in result_content @@ -281,18 +268,12 @@ class TestBatchAPIEdgeCases: def setup_method(self): """Set up test fixtures.""" - self.llm = ChatOpenAI( - model="gpt-3.5-turbo", - temperature=0.1, - max_tokens=50, - ) + self.llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.1, max_tokens=50) @pytest.mark.scheduled def test_batch_with_very_short_timeout(self): """Test batch processing with very short timeout.""" - messages_list = [ - [HumanMessage(content="Hello")], - ] + messages_list = [[HumanMessage(content="Hello")]] batch_id = self.llm.batch_create(messages_list=messages_list) @@ -313,10 +294,8 @@ class TestBatchAPIEdgeCases: def test_batch_with_long_content(self): """Test batch processing with longer content.""" long_content = "Please summarize this text: " + "This is a test sentence. 
" * 20 - - messages_list = [ - [HumanMessage(content=long_content)], - ] + + messages_list = [[HumanMessage(content=long_content)]] batch_id = self.llm.batch_create( messages_list=messages_list, @@ -324,10 +303,10 @@ class TestBatchAPIEdgeCases: ) results = self.llm.batch_retrieve(batch_id, timeout=1800.0) - + assert len(results) == 1 result_content = results[0].generations[0].message.content - + # Should have some summary content assert len(result_content.strip()) > 10 @@ -353,7 +332,7 @@ class TestBatchAPIPerformance: ] start_time = time.time() - + batch_id = self.llm.batch_create( messages_list=messages_list, description="Medium batch test - 10 requests", @@ -363,7 +342,7 @@ class TestBatchAPIPerformance: results = self.llm.batch_retrieve( batch_id=batch_id, poll_interval=60.0, # Poll every minute - timeout=3600.0, # 1 hour timeout + timeout=3600.0, # 1 hour timeout ) end_time = time.time() @@ -401,20 +380,17 @@ class TestBatchAPIPerformance: # Test batch API processing time start_batch = time.time() batch_results = self.llm.batch( - messages, - use_batch_api=True, - poll_interval=30.0, - timeout=1800.0, + messages, use_batch_api=True, poll_interval=30.0, timeout=1800.0 ) batch_time = time.time() - start_batch # Verify both produce results assert len(sequential_results) == len(batch_results) == 2 - + # Log timing comparison print(f"Sequential processing: {sequential_time:.2f}s") print(f"Batch API processing: {batch_time:.2f}s") - + # Note: Batch API will typically be slower for small batches due to polling, # but should be more cost-effective for larger batches @@ -424,6 +400,7 @@ def is_openai_api_available() -> bool: """Check if OpenAI API is available and accessible.""" try: import openai + client = openai.OpenAI() # Try a simple API call to verify connectivity client.models.list() diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_batch.py b/libs/partners/openai/tests/unit_tests/chat_models/test_batch.py index adcd26eecea..71d1644404f 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_batch.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_batch.py @@ -1,12 +1,10 @@ """Test OpenAI Batch API functionality.""" import json -import time -from typing import Any, Dict, List -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import MagicMock, patch import pytest -from langchain_core.messages import AIMessage, BaseMessage, HumanMessage +from langchain_core.messages import AIMessage, HumanMessage from langchain_core.outputs import ChatGeneration, ChatResult from langchain_openai import ChatOpenAI @@ -42,7 +40,10 @@ class TestOpenAIBatchClient: "custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", - "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello"}]}, + "body": { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Hello"}], + }, } ] @@ -64,7 +65,10 @@ class TestOpenAIBatchClient: "custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", - "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello"}]}, + "body": { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Hello"}], + }, } ] @@ -76,10 +80,10 @@ class TestOpenAIBatchClient: # Mock batch status progression mock_batch_validating = MagicMock() mock_batch_validating.status = "validating" - + mock_batch_in_progress = MagicMock() mock_batch_in_progress.status = "in_progress" - + mock_batch_completed = MagicMock() mock_batch_completed.status 
= "completed" mock_batch_completed.output_file_id = "file_123" @@ -134,13 +138,17 @@ class TestOpenAIBatchClient: { "message": { "role": "assistant", - "content": "Hello! How can I help you?" + "content": "Hello! How can I help you?", } } ], - "usage": {"prompt_tokens": 10, "completion_tokens": 8, "total_tokens": 18} - } - } + "usage": { + "prompt_tokens": 10, + "completion_tokens": 8, + "total_tokens": 18, + }, + }, + }, } ] @@ -177,7 +185,7 @@ class TestOpenAIBatchProcessor: def test_create_batch_success(self): """Test successful batch creation with message conversion.""" # Mock batch client - with patch.object(self.processor, 'batch_client') as mock_batch_client: + with patch.object(self.processor, "batch_client") as mock_batch_client: mock_batch_client.create_batch.return_value = "batch_123" messages_list = [ @@ -198,7 +206,7 @@ class TestOpenAIBatchProcessor: # Verify batch requests were created correctly call_args = mock_batch_client.create_batch.call_args batch_requests = call_args[1]["batch_requests"] - + assert len(batch_requests) == 2 assert batch_requests[0]["custom_id"] == "request-0" assert batch_requests[0]["body"]["model"] == "gpt-3.5-turbo" @@ -208,7 +216,7 @@ class TestOpenAIBatchProcessor: def test_poll_batch_status_success(self): """Test successful batch status polling.""" - with patch.object(self.processor, 'batch_client') as mock_batch_client: + with patch.object(self.processor, "batch_client") as mock_batch_client: mock_batch = MagicMock() mock_batch.status = "completed" mock_batch_client.poll_batch_status.return_value = mock_batch @@ -238,13 +246,17 @@ class TestOpenAIBatchProcessor: { "message": { "role": "assistant", - "content": "2+2 equals 4." + "content": "2+2 equals 4.", } } ], - "usage": {"prompt_tokens": 10, "completion_tokens": 8, "total_tokens": 18} - } - } + "usage": { + "prompt_tokens": 10, + "completion_tokens": 8, + "total_tokens": 18, + }, + }, + }, }, { "id": "batch_req_124", @@ -256,35 +268,42 @@ class TestOpenAIBatchProcessor: { "message": { "role": "assistant", - "content": "The capital of France is Paris." + "content": "The capital of France is Paris.", } } ], - "usage": {"prompt_tokens": 12, "completion_tokens": 10, "total_tokens": 22} - } - } - } + "usage": { + "prompt_tokens": 12, + "completion_tokens": 10, + "total_tokens": 22, + }, + }, + }, + }, ] - with patch.object(self.processor, 'batch_client') as mock_batch_client: + with patch.object(self.processor, "batch_client") as mock_batch_client: mock_batch_client.poll_batch_status.return_value = mock_batch mock_batch_client.retrieve_batch_results.return_value = mock_results chat_results = self.processor.retrieve_batch_results("batch_123") assert len(chat_results) == 2 - + # Check first result assert isinstance(chat_results[0], ChatResult) assert len(chat_results[0].generations) == 1 assert isinstance(chat_results[0].generations[0].message, AIMessage) assert chat_results[0].generations[0].message.content == "2+2 equals 4." - + # Check second result assert isinstance(chat_results[1], ChatResult) assert len(chat_results[1].generations) == 1 assert isinstance(chat_results[1].generations[0].message, AIMessage) - assert chat_results[1].generations[0].message.content == "The capital of France is Paris." + assert ( + chat_results[1].generations[0].message.content + == "The capital of France is Paris." 
+ ) def test_retrieve_batch_results_with_errors(self): """Test batch result retrieval with some failed requests.""" @@ -303,13 +322,17 @@ class TestOpenAIBatchProcessor: { "message": { "role": "assistant", - "content": "Success response" + "content": "Success response", } } ], - "usage": {"prompt_tokens": 10, "completion_tokens": 8, "total_tokens": 18} - } - } + "usage": { + "prompt_tokens": 10, + "completion_tokens": 8, + "total_tokens": 18, + }, + }, + }, }, { "id": "batch_req_124", @@ -319,14 +342,14 @@ class TestOpenAIBatchProcessor: "body": { "error": { "message": "Invalid request", - "type": "invalid_request_error" + "type": "invalid_request_error", } - } - } - } + }, + }, + }, ] - with patch.object(self.processor, 'batch_client') as mock_batch_client: + with patch.object(self.processor, "batch_client") as mock_batch_client: mock_batch_client.poll_batch_status.return_value = mock_batch mock_batch_client.retrieve_batch_results.return_value = mock_results @@ -341,7 +364,7 @@ class TestBaseChatOpenAIBatchMethods: """Set up test fixtures.""" self.llm = ChatOpenAI(model="gpt-3.5-turbo", api_key="test-key") - @patch('langchain_openai.chat_models.batch.OpenAIBatchProcessor') + @patch("langchain_openai.chat_models.batch.OpenAIBatchProcessor") def test_batch_create_success(self, mock_processor_class): """Test successful batch creation.""" mock_processor = MagicMock() @@ -369,13 +392,21 @@ class TestBaseChatOpenAIBatchMethods: temperature=0.7, ) - @patch('langchain_openai.chat_models.batch.OpenAIBatchProcessor') + @patch("langchain_openai.chat_models.batch.OpenAIBatchProcessor") def test_batch_retrieve_success(self, mock_processor_class): """Test successful batch result retrieval.""" mock_processor = MagicMock() mock_chat_results = [ - ChatResult(generations=[ChatGeneration(message=AIMessage(content="2+2 equals 4."))]), - ChatResult(generations=[ChatGeneration(message=AIMessage(content="The capital of France is Paris."))]), + ChatResult( + generations=[ChatGeneration(message=AIMessage(content="2+2 equals 4."))] + ), + ChatResult( + generations=[ + ChatGeneration( + message=AIMessage(content="The capital of France is Paris.") + ) + ] + ), ] mock_processor.retrieve_batch_results.return_value = mock_chat_results mock_processor_class.return_value = mock_processor @@ -384,22 +415,27 @@ class TestBaseChatOpenAIBatchMethods: assert len(results) == 2 assert results[0].generations[0].message.content == "2+2 equals 4." - assert results[1].generations[0].message.content == "The capital of France is Paris." - + assert ( + results[1].generations[0].message.content + == "The capital of France is Paris." 
+ ) + mock_processor.poll_batch_status.assert_called_once_with( - batch_id="batch_123", - poll_interval=1.0, - timeout=60.0, + batch_id="batch_123", poll_interval=1.0, timeout=60.0 ) mock_processor.retrieve_batch_results.assert_called_once_with("batch_123") - @patch('langchain_openai.chat_models.batch.OpenAIBatchProcessor') + @patch("langchain_openai.chat_models.batch.OpenAIBatchProcessor") def test_batch_method_with_batch_api_true(self, mock_processor_class): """Test batch method with use_batch_api=True.""" mock_processor = MagicMock() mock_chat_results = [ - ChatResult(generations=[ChatGeneration(message=AIMessage(content="Response 1"))]), - ChatResult(generations=[ChatGeneration(message=AIMessage(content="Response 2"))]), + ChatResult( + generations=[ChatGeneration(message=AIMessage(content="Response 1"))] + ), + ChatResult( + generations=[ChatGeneration(message=AIMessage(content="Response 2"))] + ), ] mock_processor.create_batch.return_value = "batch_123" mock_processor.retrieve_batch_results.return_value = mock_chat_results @@ -429,7 +465,7 @@ class TestBaseChatOpenAIBatchMethods: ] # Mock the parent class batch method - with patch.object(ChatOpenAI.__bases__[0], 'batch') as mock_super_batch: + with patch.object(ChatOpenAI.__bases__[0], "batch") as mock_super_batch: mock_super_batch.return_value = [ AIMessage(content="Response 1"), AIMessage(content="Response 2"), @@ -439,9 +475,7 @@ class TestBaseChatOpenAIBatchMethods: assert len(results) == 2 mock_super_batch.assert_called_once_with( - inputs=inputs, - config=None, - return_exceptions=False, + inputs=inputs, config=None, return_exceptions=False ) def test_convert_input_to_messages_list(self): @@ -462,7 +496,7 @@ class TestBaseChatOpenAIBatchMethods: assert isinstance(result[0], HumanMessage) assert result[0].content == "Hello" - @patch('langchain_openai.chat_models.batch.OpenAIBatchProcessor') + @patch("langchain_openai.chat_models.batch.OpenAIBatchProcessor") def test_batch_create_with_error_handling(self, mock_processor_class): """Test batch creation with error handling.""" mock_processor = MagicMock() @@ -474,11 +508,13 @@ class TestBaseChatOpenAIBatchMethods: with pytest.raises(BatchError, match="Batch creation failed"): self.llm.batch_create(messages_list) - @patch('langchain_openai.chat_models.batch.OpenAIBatchProcessor') + @patch("langchain_openai.chat_models.batch.OpenAIBatchProcessor") def test_batch_retrieve_with_error_handling(self, mock_processor_class): """Test batch retrieval with error handling.""" mock_processor = MagicMock() - mock_processor.poll_batch_status.side_effect = BatchError("Batch polling failed") + mock_processor.poll_batch_status.side_effect = BatchError( + "Batch polling failed" + ) mock_processor_class.return_value = mock_processor with pytest.raises(BatchError, match="Batch polling failed"): @@ -486,12 +522,15 @@ class TestBaseChatOpenAIBatchMethods: def test_batch_method_input_conversion(self): """Test batch method handles various input formats correctly.""" - with patch.object(self.llm, 'batch_create') as mock_create, \ - patch.object(self.llm, 'batch_retrieve') as mock_retrieve: - + with ( + patch.object(self.llm, "batch_create") as mock_create, + patch.object(self.llm, "batch_retrieve") as mock_retrieve, + ): mock_create.return_value = "batch_123" mock_retrieve.return_value = [ - ChatResult(generations=[ChatGeneration(message=AIMessage(content="Response"))]), + ChatResult( + generations=[ChatGeneration(message=AIMessage(content="Response"))] + ) ] # Test with string inputs @@ -501,8 +540,8 @@ 
class TestBaseChatOpenAIBatchMethods: # Verify conversion happened mock_create.assert_called_once() call_args = mock_create.call_args[1] - messages_list = call_args['messages_list'] - + messages_list = call_args["messages_list"] + assert len(messages_list) == 1 assert len(messages_list[0]) == 1 assert isinstance(messages_list[0][0], HumanMessage) @@ -532,7 +571,7 @@ class TestBatchIntegrationScenarios: """Set up test fixtures.""" self.llm = ChatOpenAI(model="gpt-3.5-turbo", api_key="test-key") - @patch('langchain_openai.chat_models.batch.OpenAIBatchProcessor') + @patch("langchain_openai.chat_models.batch.OpenAIBatchProcessor") def test_empty_messages_list(self, mock_processor_class): """Test handling of empty messages list.""" mock_processor = MagicMock() @@ -543,16 +582,18 @@ class TestBatchIntegrationScenarios: results = self.llm.batch([], use_batch_api=True) assert results == [] - @patch('langchain_openai.chat_models.batch.OpenAIBatchProcessor') + @patch("langchain_openai.chat_models.batch.OpenAIBatchProcessor") def test_large_batch_processing(self, mock_processor_class): """Test processing of large batch.""" mock_processor = MagicMock() mock_processor.create_batch.return_value = "batch_123" - + # Create mock results for large batch num_requests = 100 mock_chat_results = [ - ChatResult(generations=[ChatGeneration(message=AIMessage(content=f"Response {i}"))]) + ChatResult( + generations=[ChatGeneration(message=AIMessage(content=f"Response {i}"))] + ) for i in range(num_requests) ] mock_processor.retrieve_batch_results.return_value = mock_chat_results @@ -565,14 +606,18 @@ class TestBatchIntegrationScenarios: for i, result in enumerate(results): assert result.content == f"Response {i}" - @patch('langchain_openai.chat_models.batch.OpenAIBatchProcessor') + @patch("langchain_openai.chat_models.batch.OpenAIBatchProcessor") def test_mixed_message_types(self, mock_processor_class): """Test batch processing with mixed message types.""" mock_processor = MagicMock() mock_processor.create_batch.return_value = "batch_123" mock_processor.retrieve_batch_results.return_value = [ - ChatResult(generations=[ChatGeneration(message=AIMessage(content="Response 1"))]), - ChatResult(generations=[ChatGeneration(message=AIMessage(content="Response 2"))]), + ChatResult( + generations=[ChatGeneration(message=AIMessage(content="Response 1"))] + ), + ChatResult( + generations=[ChatGeneration(message=AIMessage(content="Response 2"))] + ), ] mock_processor_class.return_value = mock_processor @@ -587,12 +632,12 @@ class TestBatchIntegrationScenarios: # Verify the conversion happened correctly mock_processor.create_batch.assert_called_once() call_args = mock_processor.create_batch.call_args[1] - messages_list = call_args['messages_list'] - + messages_list = call_args["messages_list"] + # First input should be converted to HumanMessage assert isinstance(messages_list[0][0], HumanMessage) assert messages_list[0][0].content == "String input" - + # Second input should remain as is assert isinstance(messages_list[1][0], HumanMessage) assert messages_list[1][0].content == "Direct message list"
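
The following is a minimal end-to-end sketch of the workflow this patch adds, using only the public surface shown above (``ChatOpenAI.batch_create``, ``batch_retrieve``, ``batch(..., use_batch_api=True)``, and ``BatchError``); the model name, polling interval, and timeout values are illustrative choices, not values prescribed by the patch, and a valid ``OPENAI_API_KEY`` is assumed.

.. code-block:: python

    from langchain_core.messages import HumanMessage
    from langchain_openai import ChatOpenAI
    from langchain_openai.chat_models.batch import BatchError

    llm = ChatOpenAI(model="gpt-3.5-turbo")  # illustrative model choice

    messages_list = [
        [HumanMessage(content="Translate 'hello' to French")],
        [HumanMessage(content="What is the capital of Italy?")],
    ]

    try:
        # Create the batch job (returns immediately with a batch ID) ...
        batch_id = llm.batch_create(messages_list, description="demo batch")
        # ... then poll until completion and get results as ChatResult objects.
        results = llm.batch_retrieve(batch_id, poll_interval=30.0, timeout=1800.0)
        for result in results:
            print(result.generations[0].message.content)
    except BatchError:
        # Fall back to the standard API if the Batch API path fails.
        fallback = llm.batch(messages_list)  # immediate results, full pricing
        for message in fallback:
            print(message.content)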