mirror of
https://github.com/hwchase17/langchain.git
synced 2026-01-29 21:30:18 +00:00
Apply patch [skip ci]
This commit is contained in:
@@ -2276,18 +2276,19 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
@override
|
||||
def batch(
|
||||
self,
|
||||
inputs: List[LanguageModelInput],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
inputs: list[LanguageModelInput],
|
||||
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
use_batch_api: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> List[BaseMessage]:
|
||||
) -> list[BaseMessage]:
|
||||
"""
|
||||
Batch process multiple inputs using either standard API or OpenAI Batch API.
|
||||
|
||||
This method provides two processing modes:
|
||||
1. Standard mode (use_batch_api=False): Uses parallel invoke for immediate results
|
||||
1. Standard mode (use_batch_api=False): Uses parallel invoke for
|
||||
immediate results
|
||||
2. Batch API mode (use_batch_api=True): Uses OpenAI's Batch API for 50% cost savings
|
||||
|
||||
Args:
|
||||
@@ -2356,7 +2357,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
|
||||
def _convert_input_to_messages(
|
||||
self, input_item: LanguageModelInput
|
||||
) -> List[BaseMessage]:
|
||||
) -> list[BaseMessage]:
|
||||
"""Convert various input formats to a list of BaseMessage objects."""
|
||||
if isinstance(input_item, list):
|
||||
# Already a list of messages
|
||||
|
||||
@@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
import json
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
import openai
|
||||
@@ -62,9 +62,9 @@ class OpenAIBatchClient:
|
||||
|
||||
def create_batch(
|
||||
self,
|
||||
requests: List[Dict[str, Any]],
|
||||
requests: list[dict[str, Any]],
|
||||
description: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, str]] = None,
|
||||
metadata: Optional[dict[str, str]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Create a new batch job with the OpenAI Batch API.
|
||||
@@ -104,7 +104,7 @@ class OpenAIBatchClient:
|
||||
except Exception as e:
|
||||
raise BatchError(f"Unexpected error creating batch: {e}") from e
|
||||
|
||||
def retrieve_batch(self, batch_id: str) -> Dict[str, Any]:
|
||||
def retrieve_batch(self, batch_id: str) -> dict[str, Any]:
|
||||
"""
|
||||
Retrieve batch information by ID.
|
||||
|
||||
@@ -146,7 +146,7 @@ class OpenAIBatchClient:
|
||||
batch_id: str,
|
||||
poll_interval: float = 10.0,
|
||||
timeout: Optional[float] = None,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Poll batch status until completion or failure.
|
||||
|
||||
@@ -189,7 +189,7 @@ class OpenAIBatchClient:
|
||||
|
||||
time.sleep(poll_interval)
|
||||
|
||||
def retrieve_batch_results(self, batch_id: str) -> List[Dict[str, Any]]:
|
||||
def retrieve_batch_results(self, batch_id: str) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Retrieve results from a completed batch.
|
||||
|
||||
@@ -239,7 +239,7 @@ class OpenAIBatchClient:
|
||||
batch_id=batch_id,
|
||||
) from e
|
||||
|
||||
def cancel_batch(self, batch_id: str) -> Dict[str, Any]:
|
||||
def cancel_batch(self, batch_id: str) -> dict[str, Any]:
|
||||
"""
|
||||
Cancel a batch job.
|
||||
|
||||
@@ -299,9 +299,9 @@ class OpenAIBatchProcessor:
|
||||
|
||||
def create_batch(
|
||||
self,
|
||||
messages_list: List[List[BaseMessage]],
|
||||
messages_list: list[List[BaseMessage]],
|
||||
description: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, str]] = None,
|
||||
metadata: Optional[dict[str, str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""
|
||||
@@ -337,7 +337,7 @@ class OpenAIBatchProcessor:
|
||||
batch_id: str,
|
||||
poll_interval: Optional[float] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Poll batch status until completion or failure.
|
||||
|
||||
@@ -358,7 +358,7 @@ class OpenAIBatchProcessor:
|
||||
timeout=timeout or self.timeout,
|
||||
)
|
||||
|
||||
def retrieve_batch_results(self, batch_id: str) -> List[ChatResult]:
|
||||
def retrieve_batch_results(self, batch_id: str) -> list[ChatResult]:
|
||||
"""
|
||||
Retrieve and convert batch results to LangChain format.
|
||||
|
||||
@@ -435,13 +435,13 @@ class OpenAIBatchProcessor:
|
||||
|
||||
def process_batch(
|
||||
self,
|
||||
messages_list: List[List[BaseMessage]],
|
||||
messages_list: list[List[BaseMessage]],
|
||||
description: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, str]] = None,
|
||||
metadata: Optional[dict[str, str]] = None,
|
||||
poll_interval: Optional[float] = None,
|
||||
timeout: Optional[float] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[ChatResult]:
|
||||
) -> list[ChatResult]:
|
||||
"""
|
||||
Complete batch processing workflow: create, poll, and retrieve results.
|
||||
|
||||
@@ -477,8 +477,8 @@ class OpenAIBatchProcessor:
|
||||
|
||||
|
||||
def create_batch_request(
|
||||
messages: List[BaseMessage], model: str, custom_id: str, **kwargs: Any
|
||||
) -> Dict[str, Any]:
|
||||
messages: list[BaseMessage], model: str, custom_id: str, **kwargs: Any
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Create a batch request object from LangChain messages.
|
||||
|
||||
|
||||
@@ -40,7 +40,10 @@ class TestBatchAPIIntegration:
|
||||
[HumanMessage(content="What is 2+2? Answer with just the number.")],
|
||||
[
|
||||
HumanMessage(
|
||||
content="What is the capital of France? Answer with just the city name."
|
||||
content=(
|
||||
"What is the capital of France? "
|
||||
"Answer with just the city name."
|
||||
)
|
||||
)
|
||||
],
|
||||
]
|
||||
@@ -206,7 +209,10 @@ class TestBatchAPIIntegration:
|
||||
"What is the largest planet? Answer with just the planet name.",
|
||||
[
|
||||
HumanMessage(
|
||||
content="What is the smallest planet? Answer with just the planet name."
|
||||
content=(
|
||||
"What is the smallest planet? "
|
||||
"Answer with just the planet name."
|
||||
)
|
||||
)
|
||||
],
|
||||
]
|
||||
@@ -346,7 +352,7 @@ class TestBatchAPIPerformance:
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
processing_time = end_time - start_time
|
||||
_ = end_time - start_time
|
||||
|
||||
# Verify all results
|
||||
assert len(results) == 10
|
||||
|
||||
Reference in New Issue
Block a user