Apply patch [skip ci]

open-swe[bot]
2025-08-11 20:08:43 +00:00
parent 1a126d67ef
commit d2d9918386


@@ -241,6 +241,224 @@ class OpenAIBatchClient:
            raise BatchError(f"Unexpected error cancelling batch {batch_id}: {e}", batch_id=batch_id) from e


class OpenAIBatchProcessor:
    """
    High-level processor for managing the OpenAI Batch API lifecycle with LangChain integration.

    This class handles the complete batch processing workflow:

    1. Converts LangChain messages to OpenAI batch format
    2. Creates batch jobs using the OpenAI Batch API
    3. Polls for completion with configurable intervals
    4. Converts results back to LangChain format
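
    Example (an illustrative sketch; the client setup and model name are
    assumptions, not part of this module):

        processor = OpenAIBatchProcessor(client=openai.OpenAI(), model="gpt-4o-mini")
        results = processor.process_batch([[HumanMessage(content="Hello!")]])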
"""

    def __init__(
        self,
        client: openai.OpenAI,
        model: str,
        poll_interval: float = 10.0,
        timeout: Optional[float] = None,
    ):
        """
        Initialize the batch processor.

        Args:
            client: OpenAI client instance to use for API calls.
            model: The model to use for batch requests.
            poll_interval: Default time in seconds between status checks.
            timeout: Default maximum time in seconds to wait for completion.
        """
        self.batch_client = OpenAIBatchClient(client)
        self.model = model
        self.poll_interval = poll_interval
        self.timeout = timeout

    def create_batch(
        self,
        messages_list: List[List[BaseMessage]],
        description: Optional[str] = None,
        metadata: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> str:
        """
        Create a batch job from a list of LangChain message sequences.

        Args:
            messages_list: List of message sequences to process in batch.
            description: Optional description for the batch job.
            metadata: Optional metadata to attach to the batch job.
            **kwargs: Additional parameters to pass to chat completions.

        Returns:
            The batch ID for tracking the job.

        Raises:
            BatchError: If batch creation fails.
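
        Example (an illustrative sketch; pairs with ``poll_batch_status`` and
        ``retrieve_batch_results`` for a manual workflow):

            batch_id = processor.create_batch([[HumanMessage(content="Hi")]])
            processor.poll_batch_status(batch_id)
            results = processor.retrieve_batch_results(batch_id)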
"""
# Convert LangChain messages to batch requests
requests = []
for i, messages in enumerate(messages_list):
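            # Ids look like "request_0_ab12cd34" (suffix shown is illustrative):
            # the index preserves request order; the uuid fragment keeps ids unique.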
            custom_id = f"request_{i}_{uuid4().hex[:8]}"
            request = create_batch_request(
                messages=messages,
                model=self.model,
                custom_id=custom_id,
                **kwargs,
            )
            requests.append(request)

        return self.batch_client.create_batch(
            requests=requests,
            description=description,
            metadata=metadata,
        )

    def poll_batch_status(
        self,
        batch_id: str,
        poll_interval: Optional[float] = None,
        timeout: Optional[float] = None,
    ) -> Dict[str, Any]:
        """
        Poll batch status until completion or failure.

        Args:
            batch_id: The batch ID to poll.
            poll_interval: Time in seconds between status checks. Uses default if None.
            timeout: Maximum time in seconds to wait. Uses default if None.

        Returns:
            Final batch information when completed.

        Raises:
            BatchError: If the batch fails or times out.
        """
        # Explicit None checks so an intentional 0 is not silently replaced
        # by the instance defaults (a falsy-value pitfall with `or`).
        return self.batch_client.poll_batch_status(
            batch_id=batch_id,
            poll_interval=poll_interval if poll_interval is not None else self.poll_interval,
            timeout=timeout if timeout is not None else self.timeout,
        )

    def retrieve_batch_results(self, batch_id: str) -> List[ChatResult]:
        """
        Retrieve and convert batch results to LangChain format.

        Args:
            batch_id: The batch ID to retrieve results for.

        Returns:
            List of ChatResult objects corresponding to the original message sequences.

        Raises:
            BatchError: If the batch is not completed or result retrieval fails.
        """
        # Get raw results from OpenAI
        raw_results = self.batch_client.retrieve_batch_results(batch_id)
        # Sort results back into request order. A plain string sort is lexicographic
        # ("request_10_..." < "request_2_..."), so parse out the numeric index.
        raw_results.sort(key=lambda x: int(x.get("custom_id", "request_0_0").split("_")[1]))

        # Convert to LangChain ChatResult format
        chat_results = []
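        # Each raw result should follow the Batch API output-line shape
        # (abridged sketch, not a verbatim payload):
        #   {"custom_id": "request_0_...", "error": null,
        #    "response": {"status_code": 200, "body": {...chat completion...}}}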
        for result in raw_results:
            if result.get("error"):
                # Handle individual request errors
                error_msg = f"Request failed: {result['error']}"
                raise BatchError(error_msg, batch_id=batch_id)

            response = result.get("response", {})
            if not response:
                raise BatchError(f"No response found in result: {result}", batch_id=batch_id)

            body = response.get("body", {})
            choices = body.get("choices", [])
            if not choices:
                raise BatchError(f"No choices found in response: {body}", batch_id=batch_id)

            # Convert OpenAI response to LangChain format
            generations = []
            for choice in choices:
                message_dict = choice.get("message", {})
                if not message_dict:
                    continue

                # Convert OpenAI message dict to LangChain message
                message = convert_dict_to_message(message_dict)

                # Create ChatGeneration with metadata
                generation_info = {
                    "finish_reason": choice.get("finish_reason"),
                    "logprobs": choice.get("logprobs"),
                }
                generation = ChatGeneration(
                    message=message,
                    generation_info=generation_info,
                )
                generations.append(generation)

            # Create ChatResult with usage information
            usage = body.get("usage", {})
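            # Typical usage keys, per the chat completions schema:
            # prompt_tokens, completion_tokens, total_tokens.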
            llm_output = {
                "token_usage": usage,
                "model_name": body.get("model"),
                "system_fingerprint": body.get("system_fingerprint"),
            }
            chat_result = ChatResult(
                generations=generations,
                llm_output=llm_output,
            )
            chat_results.append(chat_result)

        return chat_results

    def process_batch(
        self,
        messages_list: List[List[BaseMessage]],
        description: Optional[str] = None,
        metadata: Optional[Dict[str, str]] = None,
        poll_interval: Optional[float] = None,
        timeout: Optional[float] = None,
        **kwargs: Any,
    ) -> List[ChatResult]:
        """
        Complete batch processing workflow: create, poll, and retrieve results.

        Args:
            messages_list: List of message sequences to process in batch.
            description: Optional description for the batch job.
            metadata: Optional metadata to attach to the batch job.
            poll_interval: Time in seconds between status checks. Uses default if None.
            timeout: Maximum time in seconds to wait. Uses default if None.
            **kwargs: Additional parameters to pass to chat completions.

        Returns:
            List of ChatResult objects corresponding to the original message sequences.

        Raises:
            BatchError: If any step of the batch processing fails.
        """
        # Create the batch
        batch_id = self.create_batch(
            messages_list=messages_list,
            description=description,
            metadata=metadata,
            **kwargs,
        )

        # Poll until completion
        self.poll_batch_status(
            batch_id=batch_id,
            poll_interval=poll_interval,
            timeout=timeout,
        )

        # Retrieve and return results
        return self.retrieve_batch_results(batch_id)


def create_batch_request(
    messages: List[BaseMessage],
    model: str,
@@ -273,3 +491,4 @@ def create_batch_request(
        }
    }