mirror of https://github.com/hwchase17/langchain.git, synced 2026-01-30 05:47:54 +00:00
Apply patch [skip ci]
@@ -241,6 +241,224 @@ class OpenAIBatchClient:
            raise BatchError(f"Unexpected error cancelling batch {batch_id}: {e}", batch_id=batch_id) from e


class OpenAIBatchProcessor:
    """
    High-level processor for managing OpenAI Batch API lifecycle with LangChain integration.

    This class handles the complete batch processing workflow:
    1. Converts LangChain messages to OpenAI batch format
    2. Creates batch jobs using the OpenAI Batch API
    3. Polls for completion with configurable intervals
    4. Converts results back to LangChain format
    """

    def __init__(
        self,
        client: openai.OpenAI,
        model: str,
        poll_interval: float = 10.0,
        timeout: Optional[float] = None,
    ):
        """
        Initialize the batch processor.

        Args:
            client: OpenAI client instance to use for API calls.
            model: The model to use for batch requests.
            poll_interval: Default time in seconds between status checks.
            timeout: Default maximum time in seconds to wait for completion.
        """
        self.batch_client = OpenAIBatchClient(client)
        self.model = model
        self.poll_interval = poll_interval
        self.timeout = timeout

    def create_batch(
        self,
        messages_list: List[List[BaseMessage]],
        description: Optional[str] = None,
        metadata: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> str:
        """
        Create a batch job from a list of LangChain message sequences.

        Args:
            messages_list: List of message sequences to process in batch.
            description: Optional description for the batch job.
            metadata: Optional metadata to attach to the batch job.
            **kwargs: Additional parameters to pass to chat completions.

        Returns:
            The batch ID for tracking the job.

        Raises:
            BatchError: If batch creation fails.
        """
        # Convert LangChain messages to batch requests
        requests = []
        for i, messages in enumerate(messages_list):
            custom_id = f"request_{i}_{uuid4().hex[:8]}"
            request = create_batch_request(
                messages=messages,
                model=self.model,
                custom_id=custom_id,
                **kwargs
            )
            requests.append(request)

        return self.batch_client.create_batch(
            requests=requests,
            description=description,
            metadata=metadata,
        )

    def poll_batch_status(
        self,
        batch_id: str,
        poll_interval: Optional[float] = None,
        timeout: Optional[float] = None,
    ) -> Dict[str, Any]:
        """
        Poll batch status until completion or failure.

        Args:
            batch_id: The batch ID to poll.
            poll_interval: Time in seconds between status checks. Uses default if None.
            timeout: Maximum time in seconds to wait. Uses default if None.

        Returns:
            Final batch information when completed.

        Raises:
            BatchError: If batch fails or times out.
        """
        return self.batch_client.poll_batch_status(
            batch_id=batch_id,
            poll_interval=poll_interval or self.poll_interval,
            timeout=timeout or self.timeout,
        )

    def retrieve_batch_results(self, batch_id: str) -> List[ChatResult]:
        """
        Retrieve and convert batch results to LangChain format.

        Args:
            batch_id: The batch ID to retrieve results for.

        Returns:
            List of ChatResult objects corresponding to the original message sequences.

        Raises:
            BatchError: If batch is not completed or result retrieval fails.
        """
        # Get raw results from OpenAI
        raw_results = self.batch_client.retrieve_batch_results(batch_id)

        # Sort results by the numeric index embedded in custom_id ("request_{i}_{hex}")
        # so they come back in the original request order; a plain string sort on
        # custom_id would place "request_10_..." before "request_2_...".
        raw_results.sort(key=lambda x: int(x.get("custom_id", "request_0_").split("_")[1]))

        # Convert to LangChain ChatResult format
        chat_results = []
        for result in raw_results:
            if result.get("error"):
                # Handle individual request errors
                error_msg = f"Request failed: {result['error']}"
                raise BatchError(error_msg, batch_id=batch_id)

            response = result.get("response", {})
            if not response:
                raise BatchError(f"No response found in result: {result}", batch_id=batch_id)

            body = response.get("body", {})
            choices = body.get("choices", [])

            if not choices:
                raise BatchError(f"No choices found in response: {body}", batch_id=batch_id)

            # Convert OpenAI response to LangChain format
            generations = []
            for choice in choices:
                message_dict = choice.get("message", {})
                if not message_dict:
                    continue

                # Convert OpenAI message dict to LangChain message
                message = convert_dict_to_message(message_dict)

                # Create ChatGeneration with metadata
                generation_info = {
                    "finish_reason": choice.get("finish_reason"),
                    "logprobs": choice.get("logprobs"),
                }

                generation = ChatGeneration(
                    message=message,
                    generation_info=generation_info,
                )
                generations.append(generation)

            # Create ChatResult with usage information
            usage = body.get("usage", {})
            llm_output = {
                "token_usage": usage,
                "model_name": body.get("model"),
                "system_fingerprint": body.get("system_fingerprint"),
            }

            chat_result = ChatResult(
                generations=generations,
                llm_output=llm_output,
            )
            chat_results.append(chat_result)

        return chat_results

    def process_batch(
        self,
        messages_list: List[List[BaseMessage]],
        description: Optional[str] = None,
        metadata: Optional[Dict[str, str]] = None,
        poll_interval: Optional[float] = None,
        timeout: Optional[float] = None,
        **kwargs: Any,
    ) -> List[ChatResult]:
        """
        Complete batch processing workflow: create, poll, and retrieve results.

        Args:
            messages_list: List of message sequences to process in batch.
            description: Optional description for the batch job.
            metadata: Optional metadata to attach to the batch job.
            poll_interval: Time in seconds between status checks. Uses default if None.
            timeout: Maximum time in seconds to wait. Uses default if None.
            **kwargs: Additional parameters to pass to chat completions.

        Returns:
            List of ChatResult objects corresponding to the original message sequences.

        Raises:
            BatchError: If any step of the batch processing fails.
        """
        # Create the batch
        batch_id = self.create_batch(
            messages_list=messages_list,
            description=description,
            metadata=metadata,
            **kwargs
        )

        # Poll until completion
        self.poll_batch_status(
            batch_id=batch_id,
            poll_interval=poll_interval,
            timeout=timeout,
        )

        # Retrieve and return results
        return self.retrieve_batch_results(batch_id)


def create_batch_request(
    messages: List[BaseMessage],
    model: str,
@@ -273,3 +491,4 @@ def create_batch_request(
        }
    }
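For orientation, the patch only shows the tail of create_batch_request, but the parsing in retrieve_batch_results (a top-level "error" field, then response["body"]["choices"], "usage", "model", "system_fingerprint") lines up with the documented OpenAI Batch API line format. The sketch below is illustrative only: the field values are made up and the helper's exact output is not visible in these hunks.

# Illustrative only -- one request line that create_batch_request presumably builds
# for the batch input file; custom_id ties each result back to its message sequence.
request_line = {
    "custom_id": "request_0_1a2b3c4d",
    "method": "POST",
    "url": "/v1/chat/completions",
    "body": {
        "model": "gpt-4o-mini",  # example model name
        "messages": [{"role": "user", "content": "What is 2 + 2?"}],
    },
}

# Illustrative only -- one result line in the shape retrieve_batch_results expects:
# a top-level "error" (None on success) and a chat-completions payload under
# response["body"], which is where choices, usage, model and system_fingerprint live.
result_line = {
    "custom_id": "request_0_1a2b3c4d",
    "error": None,
    "response": {
        "status_code": 200,
        "body": {
            "model": "gpt-4o-mini",
            "choices": [
                {
                    "message": {"role": "assistant", "content": "4"},
                    "finish_reason": "stop",
                    "logprobs": None,
                }
            ],
            "usage": {"prompt_tokens": 14, "completion_tokens": 1, "total_tokens": 15},
            "system_fingerprint": "fp_example",
        },
    },
}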
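And a minimal usage sketch of the new processor. The import path below is hypothetical (the patch does not show which module exports the class), the model name is only an example, and temperature is shown simply as an extra kwarg forwarded to each chat completion per the create_batch docstring.

import openai
from langchain_core.messages import HumanMessage, SystemMessage

# Hypothetical import path -- adjust to wherever this module actually lives.
from langchain_openai.chat_models.batch import OpenAIBatchProcessor

processor = OpenAIBatchProcessor(
    client=openai.OpenAI(),   # reads OPENAI_API_KEY from the environment
    model="gpt-4o-mini",      # example model name
    poll_interval=30.0,       # check batch status every 30 seconds
    timeout=24 * 60 * 60,     # stop waiting after a 24-hour window
)

messages_list = [
    [SystemMessage(content="Reply in French."), HumanMessage(content="Hello!")],
    [HumanMessage(content="What is 2 + 2?")],
]

# create_batch -> poll_batch_status -> retrieve_batch_results in one call;
# extra kwargs such as temperature are passed through to each chat completion.
results = processor.process_batch(messages_list, description="demo batch", temperature=0)

for chat_result in results:
    print(chat_result.generations[0].message.content)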