Apply patch [skip ci]

This commit is contained in:
open-swe[bot]
2025-08-11 20:37:14 +00:00
parent a9b8e5cd18
commit 9bc4c99c1c
3 changed files with 31 additions and 24 deletions

View File

@@ -2276,18 +2276,19 @@ class BaseChatOpenAI(BaseChatModel):
@override
def batch(
self,
inputs: List[LanguageModelInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
inputs: list[LanguageModelInput],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
use_batch_api: bool = False,
**kwargs: Any,
) -> List[BaseMessage]:
) -> list[BaseMessage]:
"""
Batch process multiple inputs using either standard API or OpenAI Batch API.
This method provides two processing modes:
1. Standard mode (use_batch_api=False): Uses parallel invoke for immediate results
1. Standard mode (use_batch_api=False): Uses parallel invoke for
immediate results
2. Batch API mode (use_batch_api=True): Uses OpenAI's Batch API for 50% cost savings
Args:
@@ -2356,7 +2357,7 @@ class BaseChatOpenAI(BaseChatModel):
def _convert_input_to_messages(
self, input_item: LanguageModelInput
) -> List[BaseMessage]:
) -> list[BaseMessage]:
"""Convert various input formats to a list of BaseMessage objects."""
if isinstance(input_item, list):
# Already a list of messages

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
import json
import time
from enum import Enum
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from uuid import uuid4
import openai
@@ -62,9 +62,9 @@ class OpenAIBatchClient:
def create_batch(
self,
requests: List[Dict[str, Any]],
requests: list[dict[str, Any]],
description: Optional[str] = None,
metadata: Optional[Dict[str, str]] = None,
metadata: Optional[dict[str, str]] = None,
) -> str:
"""
Create a new batch job with the OpenAI Batch API.
@@ -104,7 +104,7 @@ class OpenAIBatchClient:
except Exception as e:
raise BatchError(f"Unexpected error creating batch: {e}") from e
def retrieve_batch(self, batch_id: str) -> Dict[str, Any]:
def retrieve_batch(self, batch_id: str) -> dict[str, Any]:
"""
Retrieve batch information by ID.
@@ -146,7 +146,7 @@ class OpenAIBatchClient:
batch_id: str,
poll_interval: float = 10.0,
timeout: Optional[float] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Poll batch status until completion or failure.
@@ -189,7 +189,7 @@ class OpenAIBatchClient:
time.sleep(poll_interval)
def retrieve_batch_results(self, batch_id: str) -> List[Dict[str, Any]]:
def retrieve_batch_results(self, batch_id: str) -> list[dict[str, Any]]:
"""
Retrieve results from a completed batch.
@@ -239,7 +239,7 @@ class OpenAIBatchClient:
batch_id=batch_id,
) from e
def cancel_batch(self, batch_id: str) -> Dict[str, Any]:
def cancel_batch(self, batch_id: str) -> dict[str, Any]:
"""
Cancel a batch job.
@@ -299,9 +299,9 @@ class OpenAIBatchProcessor:
def create_batch(
self,
messages_list: List[List[BaseMessage]],
messages_list: list[List[BaseMessage]],
description: Optional[str] = None,
metadata: Optional[Dict[str, str]] = None,
metadata: Optional[dict[str, str]] = None,
**kwargs: Any,
) -> str:
"""
@@ -337,7 +337,7 @@ class OpenAIBatchProcessor:
batch_id: str,
poll_interval: Optional[float] = None,
timeout: Optional[float] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Poll batch status until completion or failure.
@@ -358,7 +358,7 @@ class OpenAIBatchProcessor:
timeout=timeout or self.timeout,
)
def retrieve_batch_results(self, batch_id: str) -> List[ChatResult]:
def retrieve_batch_results(self, batch_id: str) -> list[ChatResult]:
"""
Retrieve and convert batch results to LangChain format.
@@ -435,13 +435,13 @@ class OpenAIBatchProcessor:
def process_batch(
self,
messages_list: List[List[BaseMessage]],
messages_list: list[List[BaseMessage]],
description: Optional[str] = None,
metadata: Optional[Dict[str, str]] = None,
metadata: Optional[dict[str, str]] = None,
poll_interval: Optional[float] = None,
timeout: Optional[float] = None,
**kwargs: Any,
) -> List[ChatResult]:
) -> list[ChatResult]:
"""
Complete batch processing workflow: create, poll, and retrieve results.
@@ -477,8 +477,8 @@ class OpenAIBatchProcessor:
def create_batch_request(
messages: List[BaseMessage], model: str, custom_id: str, **kwargs: Any
) -> Dict[str, Any]:
messages: list[BaseMessage], model: str, custom_id: str, **kwargs: Any
) -> dict[str, Any]:
"""
Create a batch request object from LangChain messages.

View File

@@ -40,7 +40,10 @@ class TestBatchAPIIntegration:
[HumanMessage(content="What is 2+2? Answer with just the number.")],
[
HumanMessage(
content="What is the capital of France? Answer with just the city name."
content=(
"What is the capital of France? "
"Answer with just the city name."
)
)
],
]
@@ -206,7 +209,10 @@ class TestBatchAPIIntegration:
"What is the largest planet? Answer with just the planet name.",
[
HumanMessage(
content="What is the smallest planet? Answer with just the planet name."
content=(
"What is the smallest planet? "
"Answer with just the planet name."
)
)
],
]
@@ -346,7 +352,7 @@ class TestBatchAPIPerformance:
)
end_time = time.time()
processing_time = end_time - start_time
_ = end_time - start_time
# Verify all results
assert len(results) == 10