#!/usr/bin/env python3
"""Script to add batch_create and batch_retrieve methods to the BaseChatOpenAI class."""

import re
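
# The generated methods import OpenAIBatchProcessor from
# langchain_openai.chat_models.batch; that module does not ship with
# langchain-openai and is assumed to be created as part of the same change.
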
def add_batch_methods():
    file_path = '/home/daytona/langchain/libs/partners/openai/langchain_openai/chat_models/base.py'

    # Read the base.py file
    with open(file_path, 'r') as f:
        content = f.read()

    # Find the location to insert the methods (before _get_generation_chunk_from_completion)
    pattern = r'(\s+)def _get_generation_chunk_from_completion\('
    match = re.search(pattern, content)

    if not match:
        print("Could not find insertion point")
        return False

    indent = match.group(1)
    insert_pos = match.start()
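    # `indent` is the captured whitespace run, e.g. "\n\n    " when the target
    # def is preceded by one blank line and indented four spaces in the class.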

    # Define the methods to insert
    methods = f'''

    def batch_create(
        self,
        messages_list: List[List[BaseMessage]],
        *,
        description: Optional[str] = None,
        metadata: Optional[Dict[str, str]] = None,
        poll_interval: float = 10.0,
        timeout: Optional[float] = None,
        **kwargs: Any,
    ) -> str:
        """Create a batch job using OpenAI's Batch API for asynchronous processing.

        This method provides 50% cost savings compared to the standard API in
        exchange for asynchronous processing with polling for results.

        Args:
            messages_list: List of message sequences to process in batch.
            description: Optional description for the batch job.
            metadata: Optional metadata to attach to the batch job.
            poll_interval: Default time in seconds between status checks when polling.
            timeout: Default maximum time in seconds to wait for completion.
            **kwargs: Additional parameters to pass to chat completions.

        Returns:
            The batch ID for tracking the asynchronous job.

        Raises:
            BatchError: If batch creation fails.

        Example:
            .. code-block:: python

                from langchain_openai import ChatOpenAI
                from langchain_core.messages import HumanMessage

                llm = ChatOpenAI()
                messages_list = [
                    [HumanMessage(content="What is 2+2?")],
                    [HumanMessage(content="What is the capital of France?")],
                ]

                # Create batch job (50% cost savings)
                batch_id = llm.batch_create(messages_list)

                # Later, retrieve results
                results = llm.batch_retrieve(batch_id)
        """
        # Import here to avoid circular imports
        from langchain_openai.chat_models.batch import OpenAIBatchProcessor

        # Create batch processor with current model settings
        processor = OpenAIBatchProcessor(
            client=self.root_client,
            model=self.model_name,
            poll_interval=poll_interval,
            timeout=timeout,
        )

        # Filter and prepare kwargs for batch processing
        batch_kwargs = self._get_invocation_params(**kwargs)
        # Remove model from kwargs since it's handled by the processor
        batch_kwargs.pop("model", None)

        return processor.create_batch(
            messages_list=messages_list,
            description=description,
            metadata=metadata,
            **batch_kwargs,
        )

    def batch_retrieve(
        self,
        batch_id: str,
        *,
        poll_interval: Optional[float] = None,
        timeout: Optional[float] = None,
    ) -> List[ChatResult]:
        """Retrieve results from a batch job, polling until completion if necessary.

        This method will poll the batch status until completion and return the
        results converted to LangChain ChatResult format.

        Args:
            batch_id: The batch ID returned from batch_create().
            poll_interval: Time in seconds between status checks. Uses the default if None.
            timeout: Maximum time in seconds to wait. Uses the default if None.

        Returns:
            List of ChatResult objects corresponding to the original message sequences.

        Raises:
            BatchError: If batch retrieval fails, times out, or the batch job failed.

        Example:
            .. code-block:: python

                # After creating a batch job
                batch_id = llm.batch_create(messages_list)

                # Retrieve results (will poll until completion)
                results = llm.batch_retrieve(batch_id)

                for result in results:
                    print(result.generations[0].message.content)
        """
        # Import here to avoid circular imports
        from langchain_openai.chat_models.batch import OpenAIBatchProcessor

        # Create batch processor with current model settings
        processor = OpenAIBatchProcessor(
            client=self.root_client,
            model=self.model_name,
            poll_interval=poll_interval or 10.0,
            timeout=timeout,
        )

        # Poll for completion, then retrieve results
        processor.poll_batch_status(
            batch_id=batch_id,
            poll_interval=poll_interval,
            timeout=timeout,
        )

        return processor.retrieve_batch_results(batch_id)
{indent}'''
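    # NOTE: `methods` is an f-string used as a code template: only {indent} is
    # interpolated, so any literal braces added to the generated source would
    # have to be doubled ({{ and }}).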

    # Insert the methods. Splicing past the captured whitespace (match.end(1))
    # avoids duplicating it, since the template re-emits it via {indent}.
    new_content = content[:insert_pos] + methods + content[match.end(1):]

    # Write back to file
    with open(file_path, 'w') as f:
        f.write(new_content)

    print('Successfully added batch_create and batch_retrieve methods to the BaseChatOpenAI class')
    return True


if __name__ == "__main__":
    add_batch_methods()
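
# Usage sketch (the filename below is illustrative, not from the original
# source; the checkout path hard-coded in add_batch_methods above targets the
# sandbox this script was presumably run in):
#   $ python add_batch_methods.py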