diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index 785b2ce9c24..f996373bec3 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -2276,18 +2276,19 @@ class BaseChatOpenAI(BaseChatModel):
     @override
     def batch(
         self,
-        inputs: List[LanguageModelInput],
-        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
+        inputs: list[LanguageModelInput],
+        config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
         *,
         return_exceptions: bool = False,
         use_batch_api: bool = False,
         **kwargs: Any,
-    ) -> List[BaseMessage]:
+    ) -> list[BaseMessage]:
         """
         Batch process multiple inputs using either standard API or OpenAI Batch API.

         This method provides two processing modes:
-        1. Standard mode (use_batch_api=False): Uses parallel invoke for immediate results
+        1. Standard mode (use_batch_api=False): Uses parallel invoke for
+           immediate results
         2. Batch API mode (use_batch_api=True): Uses OpenAI's Batch API for 50% cost savings

         Args:
@@ -2356,7 +2357,7 @@ class BaseChatOpenAI(BaseChatModel):

     def _convert_input_to_messages(
         self, input_item: LanguageModelInput
-    ) -> List[BaseMessage]:
+    ) -> list[BaseMessage]:
         """Convert various input formats to a list of BaseMessage objects."""
         if isinstance(input_item, list):
             # Already a list of messages
diff --git a/libs/partners/openai/langchain_openai/chat_models/batch.py b/libs/partners/openai/langchain_openai/chat_models/batch.py
index bdb19095769..105f49e8d67 100644
--- a/libs/partners/openai/langchain_openai/chat_models/batch.py
+++ b/libs/partners/openai/langchain_openai/chat_models/batch.py
@@ -5,7 +5,7 @@ from __future__ import annotations
 import json
 import time
 from enum import Enum
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
 from uuid import uuid4

 import openai
@@ -62,9 +62,9 @@ class OpenAIBatchClient:

     def create_batch(
         self,
-        requests: List[Dict[str, Any]],
+        requests: list[dict[str, Any]],
         description: Optional[str] = None,
-        metadata: Optional[Dict[str, str]] = None,
+        metadata: Optional[dict[str, str]] = None,
     ) -> str:
         """
         Create a new batch job with the OpenAI Batch API.
@@ -104,7 +104,7 @@ class OpenAIBatchClient:
         except Exception as e:
             raise BatchError(f"Unexpected error creating batch: {e}") from e

-    def retrieve_batch(self, batch_id: str) -> Dict[str, Any]:
+    def retrieve_batch(self, batch_id: str) -> dict[str, Any]:
         """
         Retrieve batch information by ID.

@@ -146,7 +146,7 @@ class OpenAIBatchClient:
         batch_id: str,
         poll_interval: float = 10.0,
         timeout: Optional[float] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """
         Poll batch status until completion or failure.

@@ -189,7 +189,7 @@ class OpenAIBatchClient:

             time.sleep(poll_interval)

-    def retrieve_batch_results(self, batch_id: str) -> List[Dict[str, Any]]:
+    def retrieve_batch_results(self, batch_id: str) -> list[dict[str, Any]]:
         """
         Retrieve results from a completed batch.

@@ -239,7 +239,7 @@ class OpenAIBatchClient:
                 batch_id=batch_id,
             ) from e

-    def cancel_batch(self, batch_id: str) -> Dict[str, Any]:
+    def cancel_batch(self, batch_id: str) -> dict[str, Any]:
         """
         Cancel a batch job.

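The hunks above settle the low-level `OpenAIBatchClient` surface: `create_batch`, `retrieve_batch`, `retrieve_batch_results`, and `cancel_batch`, plus `create_batch_request` (its signature appears in a later hunk). A minimal end-to-end sketch against those signatures follows. The client's constructor is not shown in this diff, so its argument is an assumption; the polling helper's method name is likewise not visible here, so a manual retrieve-and-sleep loop stands in for it.

import time

import openai

from langchain_core.messages import HumanMessage
from langchain_openai.chat_models.batch import (
    OpenAIBatchClient,
    create_batch_request,
)

# Assumed constructor: this diff never shows OpenAIBatchClient.__init__.
client = OpenAIBatchClient(client=openai.OpenAI())

requests = [
    create_batch_request(
        messages=[HumanMessage(content="What is 2+2?")],
        model="gpt-4o-mini",  # model name illustrative
        custom_id="request-0",
    )
]
batch_id = client.create_batch(requests, description="demo batch")

# Manual polling loop; the patch's own helper (docstring: "Poll batch status
# until completion or failure") wraps equivalent retrieve-and-sleep logic.
# The "status" key and its values come from OpenAI's Batch API; the exact
# shape of the returned dict is assumed from the annotated return type.
while True:
    batch = client.retrieve_batch(batch_id)  # dict[str, Any]
    if batch["status"] in ("completed", "failed", "expired", "cancelled"):
        break
    time.sleep(10.0)

if batch["status"] == "completed":
    results = client.retrieve_batch_results(batch_id)  # list[dict[str, Any]]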
@@ -299,9 +299,9 @@ class OpenAIBatchProcessor:

     def create_batch(
         self,
-        messages_list: List[List[BaseMessage]],
+        messages_list: list[list[BaseMessage]],
         description: Optional[str] = None,
-        metadata: Optional[Dict[str, str]] = None,
+        metadata: Optional[dict[str, str]] = None,
         **kwargs: Any,
     ) -> str:
         """
@@ -337,7 +337,7 @@ class OpenAIBatchProcessor:
         batch_id: str,
         poll_interval: Optional[float] = None,
         timeout: Optional[float] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """
         Poll batch status until completion or failure.

@@ -358,7 +358,7 @@ class OpenAIBatchProcessor:
             timeout=timeout or self.timeout,
         )

-    def retrieve_batch_results(self, batch_id: str) -> List[ChatResult]:
+    def retrieve_batch_results(self, batch_id: str) -> list[ChatResult]:
         """
         Retrieve and convert batch results to LangChain format.

@@ -435,13 +435,13 @@ class OpenAIBatchProcessor:

     def process_batch(
         self,
-        messages_list: List[List[BaseMessage]],
+        messages_list: list[list[BaseMessage]],
         description: Optional[str] = None,
-        metadata: Optional[Dict[str, str]] = None,
+        metadata: Optional[dict[str, str]] = None,
         poll_interval: Optional[float] = None,
         timeout: Optional[float] = None,
         **kwargs: Any,
-    ) -> List[ChatResult]:
+    ) -> list[ChatResult]:
         """
         Complete batch processing workflow: create, poll, and retrieve results.

@@ -477,8 +477,8 @@


 def create_batch_request(
-    messages: List[BaseMessage], model: str, custom_id: str, **kwargs: Any
-) -> Dict[str, Any]:
+    messages: list[BaseMessage], model: str, custom_id: str, **kwargs: Any
+) -> dict[str, Any]:
     """
     Create a batch request object from LangChain messages.

diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_batch_integration.py b/libs/partners/openai/tests/integration_tests/chat_models/test_batch_integration.py
index 44a8016a472..868a344034d 100644
--- a/libs/partners/openai/tests/integration_tests/chat_models/test_batch_integration.py
+++ b/libs/partners/openai/tests/integration_tests/chat_models/test_batch_integration.py
@@ -40,7 +40,10 @@ class TestBatchAPIIntegration:
             [HumanMessage(content="What is 2+2? Answer with just the number.")],
             [
                 HumanMessage(
-                    content="What is the capital of France? Answer with just the city name."
+                    content=(
+                        "What is the capital of France? "
+                        "Answer with just the city name."
+                    )
                 )
             ],
         ]
@@ -206,7 +209,10 @@ class TestBatchAPIIntegration:
             "What is the largest planet? Answer with just the planet name.",
             [
                 HumanMessage(
-                    content="What is the smallest planet? Answer with just the planet name."
+                    content=(
+                        "What is the smallest planet? "
+                        "Answer with just the planet name."
+                    )
                 )
             ],
         ]
@@ -346,7 +352,7 @@ class TestBatchAPIPerformance:
         )
         end_time = time.time()

-        processing_time = end_time - start_time
+        _ = end_time - start_time

         # Verify all results
         assert len(results) == 10
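Taken together, the `base.py` hunks and the integration tests imply the user-facing call pattern below. A minimal sketch, assuming `use_batch_api=True` otherwise behaves like the standard `batch()`; the model name is illustrative, and the two modes and the 50% figure come from the docstring in the first hunk.

from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini")  # model name illustrative

# Inputs may mix plain strings and message lists, as in the integration tests.
inputs = [
    "What is the largest planet? Answer with just the planet name.",
    [HumanMessage(content="What is 2+2? Answer with just the number.")],
]

# Standard mode: parallel invoke, immediate results.
immediate = llm.batch(inputs)

# Batch API mode: routed through OpenAI's Batch API at the discounted rate;
# expect minutes-to-hours latency rather than an immediate response.
discounted = llm.batch(inputs, use_batch_api=True)

for message in discounted:  # list[BaseMessage] in both modes
    print(message.content)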