diff --git a/add_batch_methods.py b/add_batch_methods.py
new file mode 100644
index 00000000000..8d9b3bec54e
--- /dev/null
+++ b/add_batch_methods.py
@@ -0,0 +1,166 @@
+#!/usr/bin/env python3
+"""Script to add batch_create and batch_retrieve methods to the BaseChatOpenAI class."""
+
+import re
+
+def add_batch_methods():
+    file_path = '/home/daytona/langchain/libs/partners/openai/langchain_openai/chat_models/base.py'
+
+    # Read the base.py file
+    with open(file_path, 'r') as f:
+        content = f.read()
+
+    # Find the location to insert the methods (before _get_generation_chunk_from_completion)
+    pattern = r'(\s+)def _get_generation_chunk_from_completion\('
+    match = re.search(pattern, content)
+
+    if not match:
+        print("Could not find insertion point")
+        return False
+
+    indent = match.group(1)
+    insert_pos = match.start()
+
+    # Define the methods to insert
+    methods = f'''
+    def batch_create(
+        self,
+        messages_list: List[List[BaseMessage]],
+        *,
+        description: Optional[str] = None,
+        metadata: Optional[Dict[str, str]] = None,
+        poll_interval: float = 10.0,
+        timeout: Optional[float] = None,
+        **kwargs: Any,
+    ) -> str:
+        """
+        Create a batch job using OpenAI's Batch API for asynchronous processing.
+
+        This method offers a 50% cost saving compared to the synchronous API, in
+        exchange for asynchronous processing; results are retrieved by polling.
+
+        Args:
+            messages_list: List of message sequences to process in the batch.
+            description: Optional description for the batch job.
+            metadata: Optional metadata to attach to the batch job.
+            poll_interval: Default time in seconds between status checks when polling.
+            timeout: Default maximum time in seconds to wait for completion.
+            **kwargs: Additional parameters to pass to chat completions.
+
+        Returns:
+            The batch ID for tracking the asynchronous job.
+
+        Raises:
+            BatchError: If batch creation fails.
+
+        Example:
+            .. code-block:: python
+
+                from langchain_openai import ChatOpenAI
+                from langchain_core.messages import HumanMessage
+
+                llm = ChatOpenAI()
+                messages_list = [
+                    [HumanMessage(content="What is 2+2?")],
+                    [HumanMessage(content="What is the capital of France?")],
+                ]
+
+                # Create batch job (50% cost savings)
+                batch_id = llm.batch_create(messages_list)
+
+                # Later, retrieve results
+                results = llm.batch_retrieve(batch_id)
+        """
+        # Import here to avoid circular imports
+        from langchain_openai.chat_models.batch import OpenAIBatchProcessor
+
+        # Create batch processor with current model settings
+        processor = OpenAIBatchProcessor(
+            client=self.root_client,
+            model=self.model_name,
+            poll_interval=poll_interval,
+            timeout=timeout,
+        )
+
+        # Filter and prepare kwargs for batch processing
+        batch_kwargs = self._get_invocation_params(**kwargs)
+        # Remove model from kwargs since it's handled by the processor
+        batch_kwargs.pop("model", None)
+
+        return processor.create_batch(
+            messages_list=messages_list,
+            description=description,
+            metadata=metadata,
+            **batch_kwargs,
+        )
+
+    def batch_retrieve(
+        self,
+        batch_id: str,
+        *,
+        poll_interval: Optional[float] = None,
+        timeout: Optional[float] = None,
+    ) -> List[ChatResult]:
+        """
+        Retrieve results from a batch job, polling until completion if necessary.
+
+        This method polls the batch status until completion and returns the results
+        converted to LangChain ChatResult format.
+
+        Args:
+            batch_id: The batch ID returned from batch_create().
+            poll_interval: Time in seconds between status checks. Uses default if None.
+            timeout: Maximum time in seconds to wait. Uses default if None.
+
+        Returns:
+            List of ChatResult objects corresponding to the original message sequences.
+
+        Raises:
+            BatchError: If batch retrieval fails, times out, or the batch job failed.
+
+        Example:
+            .. code-block:: python
+
+                # After creating a batch job
+                batch_id = llm.batch_create(messages_list)
+
+                # Retrieve results (will poll until completion)
+                results = llm.batch_retrieve(batch_id)
+
+                for result in results:
+                    print(result.generations[0].message.content)
+        """
+        # Import here to avoid circular imports
+        from langchain_openai.chat_models.batch import OpenAIBatchProcessor
+
+        # Create batch processor with current model settings
+        processor = OpenAIBatchProcessor(
+            client=self.root_client,
+            model=self.model_name,
+            poll_interval=poll_interval or 10.0,
+            timeout=timeout,
+        )
+
+        # Poll for completion and retrieve results
+        processor.poll_batch_status(
+            batch_id=batch_id,
+            poll_interval=poll_interval,
+            timeout=timeout,
+        )
+
+        return processor.retrieve_batch_results(batch_id)
+
+{indent}'''
+
+    # Insert the methods
+    new_content = content[:insert_pos] + methods + content[insert_pos:]
+
+    # Write back to file
+    with open(file_path, 'w') as f:
+        f.write(new_content)
+
+    print('Successfully added batch_create and batch_retrieve methods to BaseChatOpenAI class')
+    return True
+
+if __name__ == "__main__":
+    add_batch_methods()
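
Note that the patched methods import `OpenAIBatchProcessor` from `langchain_openai.chat_models.batch`, a module this diff does not create. For reference, the sketch below shows one way that interface could look, inferred from the calls made above and from OpenAI's documented Files and Batches endpoints. It is a minimal sketch, not the actual module: the class body, the `request-{i}` custom IDs, the folding of `description` into batch metadata, and the text-only message handling are all assumptions.

```python
"""Hypothetical sketch of langchain_openai/chat_models/batch.py (not part of this diff)."""
import json
import time
from typing import Any, Dict, List, Optional

from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatResult

_ROLES = {"human": "user", "ai": "assistant", "system": "system"}  # tool/function messages omitted


class BatchError(Exception):
    """Raised when a batch job fails, expires, times out, or cannot be read."""


class OpenAIBatchProcessor:
    def __init__(
        self,
        client: Any,  # assumed to be an openai.OpenAI instance (BaseChatOpenAI.root_client)
        model: str,
        poll_interval: float = 10.0,
        timeout: Optional[float] = None,
    ) -> None:
        self.client = client
        self.model = model
        self.poll_interval = poll_interval
        self.timeout = timeout

    def create_batch(
        self,
        messages_list: List[List[BaseMessage]],
        *,
        description: Optional[str] = None,
        metadata: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> str:
        # One JSONL line per request, in the Batch API's documented input format.
        lines = []
        for i, messages in enumerate(messages_list):
            body = {
                "model": self.model,
                "messages": [
                    {"role": _ROLES[m.type], "content": m.content} for m in messages
                ],
                **kwargs,
            }
            lines.append(json.dumps({
                "custom_id": f"request-{i}",
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": body,
            }))
        input_file = self.client.files.create(
            file=("batch.jsonl", "\n".join(lines).encode("utf-8")),
            purpose="batch",
        )
        if description:  # the Batches API has no description field; stash it in metadata
            metadata = {**(metadata or {}), "description": description}
        batch = self.client.batches.create(
            input_file_id=input_file.id,
            endpoint="/v1/chat/completions",
            completion_window="24h",
            metadata=metadata,
        )
        return batch.id

    def poll_batch_status(
        self,
        batch_id: str,
        *,
        poll_interval: Optional[float] = None,
        timeout: Optional[float] = None,
    ) -> None:
        interval = poll_interval or self.poll_interval
        limit = timeout if timeout is not None else self.timeout
        deadline = None if limit is None else time.monotonic() + limit
        while True:
            status = self.client.batches.retrieve(batch_id).status
            if status == "completed":
                return
            if status in ("failed", "expired", "cancelled"):
                raise BatchError(f"Batch {batch_id} ended with status {status!r}")
            if deadline is not None and time.monotonic() >= deadline:
                raise BatchError(f"Timed out waiting for batch {batch_id}")
            time.sleep(interval)

    def retrieve_batch_results(self, batch_id: str) -> List[ChatResult]:
        batch = self.client.batches.retrieve(batch_id)
        raw = self.client.files.content(batch.output_file_id).text
        # Output order is not guaranteed; restore it via the custom_id index.
        rows = sorted(
            (json.loads(line) for line in raw.splitlines() if line),
            key=lambda row: int(row["custom_id"].rsplit("-", 1)[1]),
        )
        results: List[ChatResult] = []
        for row in rows:
            msg = row["response"]["body"]["choices"][0]["message"]
            gen = ChatGeneration(message=AIMessage(content=msg["content"]))
            results.append(ChatResult(generations=[gen]))
        return results
```

The sort on `custom_id` is load-bearing: the Batch API does not guarantee that output rows come back in input order. A production implementation would also need to surface per-request error rows and handle non-text message content, neither of which this sketch attempts.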