Apply patch [skip ci]

open-swe[bot]
2025-08-11 20:17:16 +00:00
parent 3695907018
commit 02072a9473


@@ -2149,6 +2149,114 @@ class BaseChatOpenAI(BaseChatModel):
        )
        return processor.retrieve_batch_results(batch_id)

    @override
    def batch(
        self,
        inputs: List[LanguageModelInput],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        return_exceptions: bool = False,
        use_batch_api: bool = False,
        **kwargs: Any,
    ) -> List[BaseMessage]:
"""
Batch process multiple inputs using either standard API or OpenAI Batch API.
This method provides two processing modes:
1. Standard mode (use_batch_api=False): Uses parallel invoke for immediate results
2. Batch API mode (use_batch_api=True): Uses OpenAI's Batch API for 50% cost savings
Args:
inputs: List of inputs to process in batch.
config: Configuration for the batch processing.
return_exceptions: Whether to return exceptions instead of raising them.
use_batch_api: If True, use OpenAI's Batch API for cost savings with polling.
If False (default), use standard parallel processing for immediate results.
**kwargs: Additional parameters to pass to the underlying model.
Returns:
List of BaseMessage objects corresponding to the inputs.
Raises:
BatchError: If batch processing fails (when use_batch_api=True).
Note:
**Cost vs Latency Tradeoff:**
- use_batch_api=False: Immediate results, standard API pricing
- use_batch_api=True: 50% cost savings, asynchronous processing with polling
Example:
.. code-block:: python
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage
llm = ChatOpenAI()
inputs = [
[HumanMessage(content="What is 2+2?")],
[HumanMessage(content="What is the capital of France?")],
]
# Standard processing (immediate results)
results = llm.batch(inputs)
# Batch API processing (50% cost savings, polling required)
results = llm.batch(inputs, use_batch_api=True)
"""
        if use_batch_api:
            # Convert inputs to the messages_list format expected by batch_create
            messages_list = []
            for input_item in inputs:
                if isinstance(input_item, list):
                    # Already a list of messages
                    messages_list.append(input_item)
                else:
                    # Convert a single input to a list of messages
                    messages = self._convert_input_to_messages(input_item)
                    messages_list.append(messages)

            # Create the batch job and poll for results
            batch_id = self.batch_create(messages_list, **kwargs)
            chat_results = self.batch_retrieve(batch_id)

            # Convert ChatResult objects to BaseMessage objects
            return [result.generations[0].message for result in chat_results]
        else:
            # Use the parent class's standard batch implementation
            return super().batch(
                inputs=inputs,
                config=config,
                return_exceptions=return_exceptions,
                **kwargs,
            )

    def _convert_input_to_messages(
        self, input_item: LanguageModelInput
    ) -> List[BaseMessage]:
        """Convert various input formats to a list of BaseMessage objects."""
        from langchain_core.messages import HumanMessage

        if isinstance(input_item, list):
            # Already a list of messages
            return input_item
        elif isinstance(input_item, BaseMessage):
            # Single message
            return [input_item]
        elif isinstance(input_item, str):
            # String input - convert to a HumanMessage
            return [HumanMessage(content=input_item)]
        elif hasattr(input_item, "to_messages"):
            # PromptValue or similar
            return input_item.to_messages()
        else:
            # Fall back to converting to a string, then to a HumanMessage
            return [HumanMessage(content=str(input_item))]
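
For reference, a minimal usage sketch of the two modes this patch adds, assuming a configured OpenAI API key and the gpt-4o-mini model (not part of the patch). The docstring names BatchError as the failure type for the Batch API path, but its import path is not shown in this diff, so the sketch catches a generic exception.

    from langchain_core.messages import HumanMessage
    from langchain_openai import ChatOpenAI

    llm = ChatOpenAI(model="gpt-4o-mini")

    inputs = [
        [HumanMessage(content="What is 2+2?")],
        [HumanMessage(content="Name the capital of France.")],
    ]

    # Standard mode: parallel invoke, immediate results, regular pricing.
    fast_results = llm.batch(inputs)

    # Batch API mode: submits an OpenAI batch job and polls until it completes.
    # Roughly 50% cheaper, but results are not immediate.
    try:
        cheap_results = llm.batch(inputs, use_batch_api=True)
    except Exception as exc:  # docstring names BatchError; import path not shown here
        print(f"Batch job failed: {exc}")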
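Because the _convert_input_to_messages helper normalizes each item before batch_create is called, the Batch API path should accept the same mixed input formats that the standard Runnable.batch already handles. A sketch of that normalization, under the assumption that the branch above applies the helper to every non-list item:

    from langchain_core.messages import AIMessage, HumanMessage
    from langchain_core.prompts import ChatPromptTemplate
    from langchain_openai import ChatOpenAI

    llm = ChatOpenAI()

    # Each item is normalized to a list of BaseMessage objects before being
    # handed to batch_create:
    mixed_inputs = [
        "What is 2+2?",  # str -> [HumanMessage(...)]
        HumanMessage(content="Summarize tokenization in one line."),  # message -> [message]
        [  # already a list of messages: passed through unchanged
            HumanMessage(content="Hi"),
            AIMessage(content="Hello!"),
            HumanMessage(content="Translate 'hello' to French."),
        ],
        ChatPromptTemplate.from_messages(
            [("human", "Define {term} briefly.")]
        ).invoke({"term": "embeddings"}),  # PromptValue -> .to_messages()
    ]

    results = llm.batch(mixed_inputs, use_batch_api=True)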