mirror of https://github.com/hwchase17/langchain.git (synced 2026-01-30 13:50:11 +00:00)
Apply patch [skip ci]
libs/partners/openai/langchain_openai/chat_models/batch.py (new file, 273 lines)
@@ -0,0 +1,273 @@
"""OpenAI Batch API client wrapper for LangChain integration."""

from __future__ import annotations

import json
import time
from enum import Enum
from typing import Any, Dict, List, Optional

import openai
from langchain_core.messages import BaseMessage

from langchain_openai.chat_models._compat import convert_message_to_dict

class BatchStatus(str, Enum):
    """OpenAI Batch API status values."""

    VALIDATING = "validating"
    FAILED = "failed"
    IN_PROGRESS = "in_progress"
    FINALIZING = "finalizing"
    COMPLETED = "completed"
    EXPIRED = "expired"
    CANCELLING = "cancelling"
    CANCELLED = "cancelled"

class BatchError(Exception):
    """Exception raised when batch processing fails."""

    def __init__(
        self,
        message: str,
        batch_id: Optional[str] = None,
        status: Optional[str] = None,
    ):
        super().__init__(message)
        self.batch_id = batch_id
        self.status = status

class OpenAIBatchClient:
    """OpenAI Batch API client wrapper that handles batch creation,
    status polling, and result retrieval.

    This class provides a high-level interface to OpenAI's Batch API, which
    offers 50% cost savings compared to the standard API in exchange for
    asynchronous processing.
    """

    def __init__(self, client: openai.OpenAI):
        """Initialize the batch client.

        Args:
            client: OpenAI client instance to use for API calls.
        """
        self.client = client
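    # End-to-end usage sketch (an illustrative comment, not executed; the
    # model name and prompt are placeholders, and HumanMessage comes from
    # langchain_core.messages):
    #
    #     client = OpenAIBatchClient(openai.OpenAI())
    #     request = create_batch_request(
    #         messages=[HumanMessage(content="Hello")],
    #         model="gpt-4o-mini",
    #         custom_id="req-1",
    #     )
    #     batch_id = client.create_batch([request])
    #     client.poll_batch_status(batch_id)
    #     results = client.retrieve_batch_results(batch_id)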
    def create_batch(
        self,
        requests: List[Dict[str, Any]],
        description: Optional[str] = None,
        metadata: Optional[Dict[str, str]] = None,
    ) -> str:
        """Create a new batch job with the OpenAI Batch API.

        Args:
            requests: List of request objects in OpenAI batch format.
            description: Optional description, stored in the batch metadata.
            metadata: Optional metadata to attach to the batch job.

        Returns:
            The batch ID for tracking the job.

        Raises:
            BatchError: If batch creation fails.
        """
        try:
            # Serialize the requests to JSONL, one request object per line.
            jsonl_content = "\n".join(json.dumps(req) for req in requests)

            # Upload the batch input file. A (filename, bytes) tuple is used
            # so the upload carries a .jsonl filename.
            file_response = self.client.files.create(
                file=("batch_input.jsonl", jsonl_content.encode("utf-8")),
                purpose="batch",
            )

            # Fold the description into metadata, since the Batch API has no
            # separate description field.
            batch_metadata = dict(metadata or {})
            if description:
                batch_metadata.setdefault("description", description)

            # Create the batch job against the chat completions endpoint.
            batch_response = self.client.batches.create(
                input_file_id=file_response.id,
                endpoint="/v1/chat/completions",
                completion_window="24h",
                metadata=batch_metadata,
            )

            return batch_response.id

        except openai.OpenAIError as e:
            raise BatchError(f"Failed to create batch: {e}") from e
        except Exception as e:
            raise BatchError(f"Unexpected error creating batch: {e}") from e
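    # For reference, each line of the uploaded JSONL file is one request
    # object of the shape produced by create_batch_request below, e.g.
    # (model name and content are illustrative):
    #
    #     {"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions",
    #      "body": {"model": "gpt-4o-mini",
    #               "messages": [{"role": "user", "content": "Hello"}]}}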
    def retrieve_batch(self, batch_id: str) -> Dict[str, Any]:
        """Retrieve batch information by ID.

        Args:
            batch_id: The batch ID to retrieve.

        Returns:
            Dictionary containing batch information, including status.

        Raises:
            BatchError: If batch retrieval fails.
        """
        try:
            batch = self.client.batches.retrieve(batch_id)
            return {
                "id": batch.id,
                "status": batch.status,
                "created_at": batch.created_at,
                "completed_at": getattr(batch, "completed_at", None),
                "failed_at": getattr(batch, "failed_at", None),
                "expired_at": getattr(batch, "expired_at", None),
                "request_counts": getattr(batch, "request_counts", {}),
                "metadata": getattr(batch, "metadata", {}),
                "errors": getattr(batch, "errors", None),
                "output_file_id": getattr(batch, "output_file_id", None),
                "error_file_id": getattr(batch, "error_file_id", None),
            }
        except openai.OpenAIError as e:
            raise BatchError(
                f"Failed to retrieve batch {batch_id}: {e}", batch_id=batch_id
            ) from e
        except Exception as e:
            raise BatchError(
                f"Unexpected error retrieving batch {batch_id}: {e}",
                batch_id=batch_id,
            ) from e
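    # Quick status check with the dictionary returned above (sketch;
    # `batch_id` is assumed to come from an earlier create_batch call):
    #
    #     info = client.retrieve_batch(batch_id)
    #     if info["status"] == BatchStatus.IN_PROGRESS:
    #         print(info["request_counts"])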
    def poll_batch_status(
        self,
        batch_id: str,
        poll_interval: float = 10.0,
        timeout: Optional[float] = None,
    ) -> Dict[str, Any]:
        """Poll batch status until completion or failure.

        Args:
            batch_id: The batch ID to poll.
            poll_interval: Time in seconds between status checks.
            timeout: Maximum time in seconds to wait; None means no timeout.

        Returns:
            Final batch information when completed.

        Raises:
            BatchError: If the batch fails, expires, is cancelled, or times out.
        """
        start_time = time.time()

        while True:
            batch_info = self.retrieve_batch(batch_id)
            status = batch_info["status"]

            if status == BatchStatus.COMPLETED:
                return batch_info
            if status in (BatchStatus.FAILED, BatchStatus.EXPIRED, BatchStatus.CANCELLED):
                error_msg = f"Batch {batch_id} failed with status: {status}"
                if batch_info.get("errors"):
                    error_msg += f". Errors: {batch_info['errors']}"
                raise BatchError(error_msg, batch_id=batch_id, status=status)

            # Check the timeout before sleeping; `is not None` so that a
            # zero timeout is honored.
            if timeout is not None and (time.time() - start_time) > timeout:
                raise BatchError(
                    f"Batch {batch_id} timed out after {timeout} seconds. "
                    f"Current status: {status}",
                    batch_id=batch_id,
                    status=status,
                )

            time.sleep(poll_interval)
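    # Sketch: poll every 30 seconds, giving up after one hour; the failure
    # branch shows the extra context BatchError carries.
    #
    #     try:
    #         info = client.poll_batch_status(
    #             batch_id, poll_interval=30.0, timeout=3600.0
    #         )
    #     except BatchError as e:
    #         print(e.batch_id, e.status)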
    def retrieve_batch_results(self, batch_id: str) -> List[Dict[str, Any]]:
        """Retrieve results from a completed batch.

        Args:
            batch_id: The batch ID to retrieve results for.

        Returns:
            List of result objects from the batch.

        Raises:
            BatchError: If the batch is not completed or result retrieval fails.
        """
        try:
            batch_info = self.retrieve_batch(batch_id)

            if batch_info["status"] != BatchStatus.COMPLETED:
                raise BatchError(
                    f"Batch {batch_id} is not completed. "
                    f"Current status: {batch_info['status']}",
                    batch_id=batch_id,
                    status=batch_info["status"],
                )

            output_file_id = batch_info.get("output_file_id")
            if not output_file_id:
                raise BatchError(
                    f"No output file found for batch {batch_id}", batch_id=batch_id
                )

            # Download and parse the JSONL results file, one result per line.
            file_content = self.client.files.content(output_file_id)
            results = []
            for line in file_content.text.strip().split("\n"):
                if line.strip():
                    results.append(json.loads(line))

            return results

        except BatchError:
            # Re-raise our own errors unchanged instead of letting the
            # generic handler below re-wrap them as "unexpected".
            raise
        except openai.OpenAIError as e:
            raise BatchError(
                f"Failed to retrieve results for batch {batch_id}: {e}",
                batch_id=batch_id,
            ) from e
        except Exception as e:
            raise BatchError(
                f"Unexpected error retrieving results for batch {batch_id}: {e}",
                batch_id=batch_id,
            ) from e
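    # Each parsed result line follows OpenAI's batch output format: the chat
    # completion payload sits under response.body, keyed back to its request
    # by custom_id (values shown are illustrative):
    #
    #     {"id": "batch_req_...", "custom_id": "req-1",
    #      "response": {"status_code": 200, "body": {...chat completion...}},
    #      "error": null}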
    def cancel_batch(self, batch_id: str) -> Dict[str, Any]:
        """Cancel a batch job.

        Args:
            batch_id: The batch ID to cancel.

        Returns:
            Updated batch information after cancellation.

        Raises:
            BatchError: If batch cancellation fails.
        """
        try:
            self.client.batches.cancel(batch_id)
            # Re-fetch so the returned dictionary reflects the post-cancel
            # status ("cancelling" or "cancelled").
            return self.retrieve_batch(batch_id)
        except openai.OpenAIError as e:
            raise BatchError(
                f"Failed to cancel batch {batch_id}: {e}", batch_id=batch_id
            ) from e
        except Exception as e:
            raise BatchError(
                f"Unexpected error cancelling batch {batch_id}: {e}",
                batch_id=batch_id,
            ) from e

def create_batch_request(
    messages: List[BaseMessage],
    model: str,
    custom_id: str,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Create a batch request object from LangChain messages.

    Args:
        messages: List of LangChain messages to convert.
        model: The model to use for the request.
        custom_id: Unique identifier for this request within the batch.
        **kwargs: Additional parameters to pass to the chat completion.

    Returns:
        Dictionary in OpenAI batch request format.
    """
    # Convert LangChain messages to the OpenAI wire format.
    openai_messages = [convert_message_to_dict(msg) for msg in messages]

    return {
        "custom_id": custom_id,
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": model,
            "messages": openai_messages,
            **kwargs,
        },
    }
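A minimal end-to-end sketch of the module above, assuming a valid OPENAI_API_KEY in the environment; the model name and prompts are placeholders, not values from this patch:

import openai
from langchain_core.messages import HumanMessage

from langchain_openai.chat_models.batch import (
    OpenAIBatchClient,
    create_batch_request,
)

# Build one batch request per prompt; custom_id keys results to requests.
requests = [
    create_batch_request(
        messages=[HumanMessage(content=prompt)],
        model="gpt-4o-mini",  # placeholder model name
        custom_id=f"req-{i}",
    )
    for i, prompt in enumerate(["Hello", "What is 2 + 2?"])
]

client = OpenAIBatchClient(openai.OpenAI())
batch_id = client.create_batch(requests, description="demo batch")

# Block until the batch finishes (may take up to the 24h window).
client.poll_batch_status(batch_id, poll_interval=30.0)

for result in client.retrieve_batch_results(batch_id):
    body = result["response"]["body"]
    print(result["custom_id"], body["choices"][0]["message"]["content"])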