Apply patch [skip ci]
@@ -2149,6 +2149,114 @@ class BaseChatOpenAI(BaseChatModel):
        )

        return processor.retrieve_batch_results(batch_id)
    @override
    def batch(
        self,
        inputs: List[LanguageModelInput],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        return_exceptions: bool = False,
        use_batch_api: bool = False,
        **kwargs: Any,
    ) -> List[BaseMessage]:
        """Batch process multiple inputs using either the standard API or the OpenAI Batch API.

        This method provides two processing modes:

        1. Standard mode (use_batch_api=False): uses parallel invocation for
           immediate results.
        2. Batch API mode (use_batch_api=True): uses OpenAI's Batch API for
           50% cost savings.

        Args:
            inputs: List of inputs to process in batch.
            config: Configuration for the batch processing.
            return_exceptions: Whether to return exceptions instead of raising them.
            use_batch_api: If True, use OpenAI's Batch API for cost savings,
                polling until the job completes. If False (default), use standard
                parallel processing for immediate results.
            **kwargs: Additional parameters to pass to the underlying model.

        Returns:
            List of BaseMessage objects corresponding to the inputs.

        Raises:
            BatchError: If batch processing fails (when use_batch_api=True).

        Note:
            **Cost vs Latency Tradeoff:**

            - use_batch_api=False: immediate results, standard API pricing
            - use_batch_api=True: 50% cost savings, asynchronous processing
              with polling

        Example:
            .. code-block:: python

                from langchain_openai import ChatOpenAI
                from langchain_core.messages import HumanMessage

                llm = ChatOpenAI()
                inputs = [
                    [HumanMessage(content="What is 2+2?")],
                    [HumanMessage(content="What is the capital of France?")],
                ]

                # Standard processing (immediate results)
                results = llm.batch(inputs)

                # Batch API processing (50% cost savings, polling required)
                results = llm.batch(inputs, use_batch_api=True)
        """
        if use_batch_api:
            # Convert inputs to the messages_list format expected by batch_create
            messages_list = []
            for input_item in inputs:
                if isinstance(input_item, list):
                    # Already a list of messages
                    messages_list.append(input_item)
                else:
                    # Convert a single input to a list of messages
                    messages = self._convert_input_to_messages(input_item)
                    messages_list.append(messages)

            # Create the batch job and poll for results
            batch_id = self.batch_create(messages_list, **kwargs)
            chat_results = self.batch_retrieve(batch_id)

            # Convert ChatResult objects to BaseMessage objects
            return [result.generations[0].message for result in chat_results]
        else:
            # Fall back to the parent class's standard batch implementation
            return super().batch(
                inputs=inputs,
                config=config,
                return_exceptions=return_exceptions,
                **kwargs,
            )

    def _convert_input_to_messages(
        self, input_item: LanguageModelInput
    ) -> List[BaseMessage]:
        """Convert the supported input formats to a list of BaseMessage objects."""
        from langchain_core.messages import HumanMessage

        if isinstance(input_item, list):
            # Already a list of messages
            return input_item
        elif isinstance(input_item, BaseMessage):
            # A single message
            return [input_item]
        elif isinstance(input_item, str):
            # A string becomes a single HumanMessage
            return [HumanMessage(content=input_item)]
        elif hasattr(input_item, "to_messages"):
            # A PromptValue or similar
            return input_item.to_messages()
        else:
            # Fall back to str() wrapped in a HumanMessage
            return [HumanMessage(content=str(input_item))]
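
For context, batch_create and batch_retrieve (added earlier in this patch) sit on top of OpenAI's Batch API: requests go up as a JSONL file, a batch job is created against /v1/chat/completions, and the caller polls until the job reaches a terminal state. Below is a minimal standalone sketch of that flow; create_chat_batch and poll_batch are hypothetical helpers and the JSONL layout is illustrative, but the openai client calls themselves (files.create, batches.create, batches.retrieve, files.content) are the real Batch API surface.

import json
import time

from openai import OpenAI

client = OpenAI()


def create_chat_batch(messages_list: list, model: str = "gpt-4o-mini") -> str:
    """Upload chat requests as a JSONL file and start a batch job.

    messages_list is assumed to hold plain {"role": ..., "content": ...} dicts.
    """
    lines = [
        json.dumps(
            {
                "custom_id": f"request-{i}",
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": {"model": model, "messages": messages},
            }
        )
        for i, messages in enumerate(messages_list)
    ]
    batch_file = client.files.create(
        file=("batch.jsonl", "\n".join(lines).encode()),
        purpose="batch",
    )
    batch = client.batches.create(
        input_file_id=batch_file.id,
        endpoint="/v1/chat/completions",
        completion_window="24h",
    )
    return batch.id


def poll_batch(batch_id: str, interval: float = 30.0) -> str:
    """Poll until the batch job finishes, then return its JSONL output."""
    while True:
        batch = client.batches.retrieve(batch_id)
        if batch.status in ("completed", "failed", "expired", "cancelled"):
            break
        time.sleep(interval)
    if batch.status != "completed":
        raise RuntimeError(f"batch {batch_id} ended with status {batch.status}")
    # The output file is JSONL: one response object per request, keyed by custom_id.
    return client.files.content(batch.output_file_id).text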
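
The _convert_input_to_messages helper mirrors the input forms that LanguageModelInput already allows. Purely to illustrate the normalization (calling the private method directly, which client code would not normally do):

from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

llm = ChatOpenAI()

# Each supported form normalizes to a list of BaseMessage objects:
llm._convert_input_to_messages("What is 2+2?")              # str -> [HumanMessage]
llm._convert_input_to_messages(HumanMessage(content="hi"))  # message -> [message]
llm._convert_input_to_messages(
    [SystemMessage(content="Be terse."), HumanMessage(content="hi")]
)  # a list of messages passes through unchanged

prompt_value = ChatPromptTemplate.from_messages([("human", "{q}")]).invoke({"q": "hi"})
llm._convert_input_to_messages(prompt_value)  # PromptValue -> prompt_value.to_messages()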
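
One behavioral difference worth noting: on the use_batch_api=True path the method never consults return_exceptions (only the super().batch() fallback does), so a failed batch job raises regardless of the flag. A caller that wants exception capture in that mode has to guard the call itself, e.g.:

# Hypothetical guard; BatchError is the exception named in the docstring above.
try:
    results = llm.batch(inputs, use_batch_api=True)
except Exception as exc:  # e.g. BatchError from a failed batch job
    results = [exc] * len(inputs)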