From d2d99183861ed18daefb3fb272a33cab59ab609d Mon Sep 17 00:00:00 2001
From: "open-swe[bot]"
Date: Mon, 11 Aug 2025 20:08:43 +0000
Subject: [PATCH] Apply patch [skip ci]

---
 .../langchain_openai/chat_models/batch.py | 219 ++++++++++++++++++
 1 file changed, 219 insertions(+)

diff --git a/libs/partners/openai/langchain_openai/chat_models/batch.py b/libs/partners/openai/langchain_openai/chat_models/batch.py
index f377c8b1aef..ba369a52faa 100644
--- a/libs/partners/openai/langchain_openai/chat_models/batch.py
+++ b/libs/partners/openai/langchain_openai/chat_models/batch.py
@@ -241,6 +241,224 @@ class OpenAIBatchClient:
             raise BatchError(f"Unexpected error cancelling batch {batch_id}: {e}", batch_id=batch_id) from e
 
 
+class OpenAIBatchProcessor:
+    """
+    High-level processor for managing OpenAI Batch API lifecycle with LangChain integration.
+
+    This class handles the complete batch processing workflow:
+    1. Converts LangChain messages to OpenAI batch format
+    2. Creates batch jobs using the OpenAI Batch API
+    3. Polls for completion with configurable intervals
+    4. Converts results back to LangChain format
+    """
+
+    def __init__(
+        self,
+        client: openai.OpenAI,
+        model: str,
+        poll_interval: float = 10.0,
+        timeout: Optional[float] = None,
+    ):
+        """
+        Initialize the batch processor.
+
+        Args:
+            client: OpenAI client instance to use for API calls.
+            model: The model to use for batch requests.
+            poll_interval: Default time in seconds between status checks.
+            timeout: Default maximum time in seconds to wait for completion.
+        """
+        self.batch_client = OpenAIBatchClient(client)
+        self.model = model
+        self.poll_interval = poll_interval
+        self.timeout = timeout
+
+    def create_batch(
+        self,
+        messages_list: List[List[BaseMessage]],
+        description: Optional[str] = None,
+        metadata: Optional[Dict[str, str]] = None,
+        **kwargs: Any,
+    ) -> str:
+        """
+        Create a batch job from a list of LangChain message sequences.
+
+        Args:
+            messages_list: List of message sequences to process in batch.
+            description: Optional description for the batch job.
+            metadata: Optional metadata to attach to the batch job.
+            **kwargs: Additional parameters to pass to chat completions.
+
+        Returns:
+            The batch ID for tracking the job.
+
+        Raises:
+            BatchError: If batch creation fails.
+        """
+        # Convert LangChain messages to batch requests
+        requests = []
+        for i, messages in enumerate(messages_list):
+            custom_id = f"request_{i}_{uuid4().hex[:8]}"
+            request = create_batch_request(
+                messages=messages,
+                model=self.model,
+                custom_id=custom_id,
+                **kwargs
+            )
+            requests.append(request)
+
+        return self.batch_client.create_batch(
+            requests=requests,
+            description=description,
+            metadata=metadata,
+        )
+
+    def poll_batch_status(
+        self,
+        batch_id: str,
+        poll_interval: Optional[float] = None,
+        timeout: Optional[float] = None,
+    ) -> Dict[str, Any]:
+        """
+        Poll batch status until completion or failure.
+
+        Args:
+            batch_id: The batch ID to poll.
+            poll_interval: Time in seconds between status checks. Uses default if None.
+            timeout: Maximum time in seconds to wait. Uses default if None.
+
+        Returns:
+            Final batch information when completed.
+
+        Raises:
+            BatchError: If batch fails or times out.
+        """
+        return self.batch_client.poll_batch_status(
+            batch_id=batch_id,
+            poll_interval=poll_interval or self.poll_interval,
+            timeout=timeout or self.timeout,
+        )
+
+    def retrieve_batch_results(self, batch_id: str) -> List[ChatResult]:
+        """
+        Retrieve and convert batch results to LangChain format.
+
+        Args:
+            batch_id: The batch ID to retrieve results for.
+
+        Returns:
+            List of ChatResult objects corresponding to the original message sequences.
+
+        Raises:
+            BatchError: If batch is not completed or result retrieval fails.
+        """
+        # Get raw results from OpenAI
+        raw_results = self.batch_client.retrieve_batch_results(batch_id)
+
+        # Sort results by the numeric index embedded in custom_id
+        # ("request_<i>_<hex>") so the original request order is preserved.
+        # A plain lexicographic sort on custom_id would misorder results
+        # (e.g. "request_10_..." would sort before "request_2_...").
+        raw_results.sort(key=lambda x: int(x.get("custom_id", "request_0_").split("_")[1]))
+
+        # Convert to LangChain ChatResult format
+        chat_results = []
+        for result in raw_results:
+            if result.get("error"):
+                # Handle individual request errors
+                error_msg = f"Request failed: {result['error']}"
+                raise BatchError(error_msg, batch_id=batch_id)
+
+            response = result.get("response", {})
+            if not response:
+                raise BatchError(f"No response found in result: {result}", batch_id=batch_id)
+
+            body = response.get("body", {})
+            choices = body.get("choices", [])
+
+            if not choices:
+                raise BatchError(f"No choices found in response: {body}", batch_id=batch_id)
+
+            # Convert OpenAI response to LangChain format
+            generations = []
+            for choice in choices:
+                message_dict = choice.get("message", {})
+                if not message_dict:
+                    continue
+
+                # Convert OpenAI message dict to LangChain message
+                message = convert_dict_to_message(message_dict)
+
+                # Create ChatGeneration with metadata
+                generation_info = {
+                    "finish_reason": choice.get("finish_reason"),
+                    "logprobs": choice.get("logprobs"),
+                }
+
+                generation = ChatGeneration(
+                    message=message,
+                    generation_info=generation_info,
+                )
+                generations.append(generation)
+
+            # Create ChatResult with usage information
+            usage = body.get("usage", {})
+            llm_output = {
+                "token_usage": usage,
+                "model_name": body.get("model"),
+                "system_fingerprint": body.get("system_fingerprint"),
+            }
+
+            chat_result = ChatResult(
+                generations=generations,
+                llm_output=llm_output,
+            )
+            chat_results.append(chat_result)
+
+        return chat_results
+
+    def process_batch(
+        self,
+        messages_list: List[List[BaseMessage]],
+        description: Optional[str] = None,
+        metadata: Optional[Dict[str, str]] = None,
+        poll_interval: Optional[float] = None,
+        timeout: Optional[float] = None,
+        **kwargs: Any,
+    ) -> List[ChatResult]:
+        """
+        Complete batch processing workflow: create, poll, and retrieve results.
+
+        Args:
+            messages_list: List of message sequences to process in batch.
+            description: Optional description for the batch job.
+            metadata: Optional metadata to attach to the batch job.
+            poll_interval: Time in seconds between status checks. Uses default if None.
+            timeout: Maximum time in seconds to wait. Uses default if None.
+            **kwargs: Additional parameters to pass to chat completions.
+
+        Returns:
+            List of ChatResult objects corresponding to the original message sequences.
+
+        Raises:
+            BatchError: If any step of the batch processing fails.
+        """
+        # Create the batch
+        batch_id = self.create_batch(
+            messages_list=messages_list,
+            description=description,
+            metadata=metadata,
+            **kwargs
+        )
+
+        # Poll until completion
+        self.poll_batch_status(
+            batch_id=batch_id,
+            poll_interval=poll_interval,
+            timeout=timeout,
+        )
+
+        # Retrieve and return results
+        return self.retrieve_batch_results(batch_id)
+
+
 def create_batch_request(
     messages: List[BaseMessage],
     model: str,
@@ -273,3 +491,4 @@ def create_batch_request(
         }
     }
 
+
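
Reviewer note (not part of the patch): a minimal end-to-end sketch of how the
new OpenAIBatchProcessor could be driven. Assumptions are flagged inline: the
import path mirrors the patched module, OPENAI_API_KEY is set in the
environment, and the model name is a placeholder.

    # Sketch only: the import path and model name are assumptions, not
    # confirmed by the patch. process_batch blocks until the batch finishes;
    # OpenAI batches can take up to 24 hours to complete, so the poll
    # interval and timeout are sized accordingly.
    import openai
    from langchain_core.messages import HumanMessage, SystemMessage

    from langchain_openai.chat_models.batch import OpenAIBatchProcessor

    client = openai.OpenAI()  # reads OPENAI_API_KEY from the environment
    processor = OpenAIBatchProcessor(
        client=client,
        model="gpt-4o-mini",   # placeholder model name
        poll_interval=30.0,    # check batch status every 30 seconds
        timeout=24 * 60 * 60,  # stop waiting after 24 hours
    )

    # Each inner list is one independent chat request in the batch.
    messages_list = [
        [SystemMessage(content="Reply in French."), HumanMessage(content="Hello!")],
        [HumanMessage(content="What is 2 + 2?")],
    ]

    # create_batch + poll_batch_status + retrieve_batch_results in one call;
    # results come back in the same order as messages_list.
    results = processor.process_batch(messages_list, description="demo batch")
    for chat_result in results:
        print(chat_result.generations[0].message.content)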
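
A second sketch, under the same assumptions, showing the step-wise methods for
long-running batches where the batch ID is persisted and polling resumes in a
later process:

    # Sketch only; continues from the setup above.
    batch_id = processor.create_batch(messages_list, description="overnight batch")
    # ... persist batch_id, exit, and resume in a later process ...
    processor.poll_batch_status(batch_id, poll_interval=60.0)
    results = processor.retrieve_batch_results(batch_id)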