mirror of https://github.com/hwchase17/langchain.git
Apply patch [skip ci]
@@ -2330,33 +2330,6 @@ class BaseChatOpenAI(BaseChatModel):
            # Batch API processing (50% cost savings, polling required)
            results = llm.batch(inputs, use_batch_api=True)
        """
        if use_batch_api:
            # Convert inputs to messages_list format expected by batch_create
            messages_list = []
            for input_item in inputs:
                if isinstance(input_item, list):
                    # Already a list of messages
                    messages_list.append(input_item)
                else:
                    # Convert single input to list of messages
                    messages = self._convert_input_to_messages(input_item)
                    messages_list.append(messages)

            # Create batch job and poll for results
            batch_id = self.batch_create(messages_list, **kwargs)
            chat_results = self.batch_retrieve(batch_id)

            # Convert ChatResult objects to BaseMessage objects
            return [result.generations[0].message for result in chat_results]
        else:
            # Use the parent class's standard batch implementation
            return super().batch(
                inputs=inputs,
                config=config,
                return_exceptions=return_exceptions,
                **kwargs,
            )

    def _convert_input_to_messages(
        self, input_item: LanguageModelInput
    ) -> list[BaseMessage]:
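The hunk cuts off at the helper's signature, so its body is not shown here. A minimal sketch of what `_convert_input_to_messages` could look like, assuming it delegates to the `_convert_input` normalizer that `langchain_core`'s `BaseChatModel` already provides (which accepts a string, `PromptValue`, or message sequence); the delegation is an assumption for illustration, not code from this patch:

    def _convert_input_to_messages(
        self, input_item: LanguageModelInput
    ) -> list[BaseMessage]:
        # Assumption: _convert_input, inherited from BaseChatModel, normalizes
        # a str, PromptValue, or sequence of messages into a PromptValue, and
        # to_messages() yields the list[BaseMessage] that batch_create expects.
        return self._convert_input(input_item).to_messages()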
@@ -74,8 +74,8 @@ class TestBatchAPIIntegration:
         )

         # Check that we got reasonable responses
-        response1 = results[0].generations[0].str(message.content).strip()
-        response2 = results[1].generations[0].str(message.content).strip()
+        response1 = results[0].generations[0].message.content.strip()
+        response2 = results[1].generations[0].message.content.strip()

         # Basic sanity checks (responses should contain expected content)
         assert "4" in response1 or "four" in response1.lower()
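The removed lines treated `str` as a method on the generation object (`generations[0].str(message.content)`), which does not exist on `ChatGeneration` and fails at runtime; the fix reads the `message.content` attribute directly. A minimal standalone sketch of the object path the corrected tests traverse, using only standard `langchain_core` output types:

from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration, ChatResult

# A ChatResult wraps a list of ChatGeneration objects; each generation
# carries the model's reply as an AIMessage whose content is the text.
result = ChatResult(generations=[ChatGeneration(message=AIMessage(content="4"))])

# The access path used by the fixed tests: generations[0].message.content
response = result.generations[0].message.content.strip()
assert "4" in response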
@@ -182,7 +182,7 @@ class TestBatchAPIIntegration:
         )

         assert len(results) == 1
-        result_content = results[0].generations[0].str(message.content).strip()
+        result_content = results[0].generations[0].message.content.strip()

         # Should contain the answer
         assert "30" in result_content or "thirty" in result_content.lower()
@@ -263,7 +263,7 @@ class TestBatchAPIIntegration:
         results = self.llm.batch_retrieve(batch_id, timeout=1800.0)

         assert len(results) == 1
-        result_content = results[0].generations[0].str(message.content).strip().lower()
+        result_content = results[0].generations[0].message.content.strip().lower()
         assert "test successful" in result_content
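For orientation, the integration tests above exercise a two-step flow: `batch_create` submits the job and returns a batch id, and `batch_retrieve` polls until the job completes (here with a 30-minute timeout). A hedged end-to-end sketch of that flow; `batch_create` and `batch_retrieve` come from this patch, while the model name and prompts are illustrative assumptions:

from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini")  # model choice is an assumption

# batch_create expects a list of message lists, one per request
messages_list = [
    [HumanMessage(content="What is 15 + 15?")],
    [HumanMessage(content="Reply with exactly: test successful")],
]

batch_id = llm.batch_create(messages_list)              # submit the batch job
results = llm.batch_retrieve(batch_id, timeout=1800.0)  # poll until complete

for chat_result in results:
    print(chat_result.generations[0].message.content)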
@@ -358,7 +358,7 @@ class TestBatchAPIPerformance:

         # Check that we got reasonable math answers
         for i, result in enumerate(results, 1):
-            content = result.generations[0].str(message.content).strip()
+            content = result.generations[0].message.content.strip()
             expected_answer = str(i + i)
             assert expected_answer in content or str(i * 2) in content