Apply patch [skip ci]

This commit is contained in:
open-swe[bot]
2025-08-11 20:52:31 +00:00
parent d7b3288e9f
commit d6171c1ef5
2 changed files with 5 additions and 32 deletions

View File

@@ -2330,33 +2330,6 @@ class BaseChatOpenAI(BaseChatModel):
# Batch API processing (50% cost savings, polling required)
results = llm.batch(inputs, use_batch_api=True)
"""
if use_batch_api:
# Convert inputs to messages_list format expected by batch_create
messages_list = []
for input_item in inputs:
if isinstance(input_item, list):
# Already a list of messages
messages_list.append(input_item)
else:
# Convert single input to list of messages
messages = self._convert_input_to_messages(input_item)
messages_list.append(messages)
# Create batch job and poll for results
batch_id = self.batch_create(messages_list, **kwargs)
chat_results = self.batch_retrieve(batch_id)
# Convert ChatResult objects to BaseMessage objects
return [result.generations[0].message for result in chat_results]
else:
# Use the parent class's standard batch implementation
return super().batch(
inputs=inputs,
config=config,
return_exceptions=return_exceptions,
**kwargs,
)
def _convert_input_to_messages(
self, input_item: LanguageModelInput
) -> list[BaseMessage]:

View File

@@ -74,8 +74,8 @@ class TestBatchAPIIntegration:
)
# Check that we got reasonable responses
response1 = results[0].generations[0].str(message.content).strip()
response2 = results[1].generations[0].str(message.content).strip()
response1 = results[0].generations[0].message.content.strip()
response2 = results[1].generations[0].message.content.strip()
# Basic sanity checks (responses should contain expected content)
assert "4" in response1 or "four" in response1.lower()
@@ -182,7 +182,7 @@ class TestBatchAPIIntegration:
)
assert len(results) == 1
result_content = results[0].generations[0].str(message.content).strip()
result_content = results[0].generations[0].message.content.strip()
# Should contain the answer
assert "30" in result_content or "thirty" in result_content.lower()
@@ -263,7 +263,7 @@ class TestBatchAPIIntegration:
results = self.llm.batch_retrieve(batch_id, timeout=1800.0)
assert len(results) == 1
result_content = results[0].generations[0].str(message.content).strip().lower()
result_content = results[0].generations[0].message.content.strip().lower()
assert "test successful" in result_content
@@ -358,7 +358,7 @@ class TestBatchAPIPerformance:
# Check that we got reasonable math answers
for i, result in enumerate(results, 1):
content = result.generations[0].str(message.content).strip()
content = result.generations[0].message.content.strip()
expected_answer = str(i + i)
assert expected_answer in content or str(i * 2) in content