openai[patch]: add explicit attribute for service tier (#31005)

ccurme 2025-04-25 14:38:23 -04:00 committed by GitHub
parent ab871a7b39
commit 629b7a5a43
4 changed files with 18 additions and 3 deletions


@@ -119,6 +119,7 @@ def test_configurable() -> None:
"reasoning_effort": None,
"frequency_penalty": None,
"seed": None,
+ "service_tier": None,
"logprobs": None,
"top_logprobs": None,
"logit_bias": None,

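The None default now sits alongside the other parameters that test_configurable checks. As a minimal sketch, and assuming the new field is surfaced through init_chat_model's configurable machinery like the parameters listed above, a per-run tier override might look like this (the model name and the "any" setting are illustrative):

    from langchain.chat_models import init_chat_model

    # Sketch only: expose all model params as configurable, then override the
    # assumed "service_tier" key for one chain of calls.
    llm = init_chat_model("openai:o4-mini", configurable_fields="any")
    flex_llm = llm.with_config(configurable={"service_tier": "flex"})
    flex_llm.invoke("Hello")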

@@ -538,6 +538,10 @@ class BaseChatOpenAI(BaseChatModel):
However this does not prevent a user from directly passing in the parameter during
invocation.
"""
+ service_tier: Optional[str] = None
+ """Latency tier for request. Options are 'auto', 'default', or 'flex'. Relevant
+ for users of OpenAI's scale tier service.
+ """
use_responses_api: Optional[bool] = None
"""Whether to use the Responses API instead of the Chat API.
@@ -655,6 +659,7 @@ class BaseChatOpenAI(BaseChatModel):
"n": self.n,
"temperature": self.temperature,
"reasoning_effort": self.reasoning_effort,
+ "service_tier": self.service_tier,
}
params = {

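Since the attribute is folded into _default_params here, a per-call override should work the same way other request parameters are overridden. A rough sketch, assuming invocation-time kwargs are merged over the instance defaults when the payload is built:

    from langchain_openai import ChatOpenAI

    llm = ChatOpenAI(model="o4-mini", service_tier="flex")

    # Assumed behavior: kwargs bound here are layered over _default_params when
    # the request payload is assembled, so this call would go out as "default".
    default_tier_llm = llm.bind(service_tier="default")
    default_tier_llm.invoke("Hello")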

@@ -215,12 +215,15 @@ async def test_openai_abatch_tags(use_responses_api: bool) -> None:
assert isinstance(token.text(), str)

@pytest.mark.scheduled
@pytest.mark.flaky(retries=3, delay=1)
def test_openai_invoke() -> None:
"""Test invoke tokens from ChatOpenAI."""
- llm = ChatOpenAI(max_tokens=MAX_TOKEN_COUNT)  # type: ignore[call-arg]
+ llm = ChatOpenAI(
+     model="o4-mini",
+     service_tier="flex",  # Also test service_tier
+ )
- result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
+ result = llm.invoke("Hello", config=dict(tags=["foo"]))
assert isinstance(result.content, str)
# assert no response headers if include_response_headers is not set


@@ -1732,3 +1732,9 @@ def test__construct_responses_api_input_multiple_message_types() -> None:
# assert no mutation has occurred
assert messages_copy == messages

+ def test_service_tier() -> None:
+     llm = ChatOpenAI(model="o4-mini", service_tier="flex")
+     payload = llm._get_request_payload([HumanMessage("Hello")])
+     assert payload["service_tier"] == "flex"
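A complementary default-case check could look like the sketch below (the test name is an assumption; it relies only on the payload either omitting the key or leaving it unset when no service_tier is configured):

    from langchain_core.messages import HumanMessage
    from langchain_openai import ChatOpenAI

    def test_service_tier_default() -> None:
        # Hypothetical counterpart to test_service_tier above: with no tier
        # configured, the built payload should carry no tier value.
        llm = ChatOpenAI(model="o4-mini")
        payload = llm._get_request_payload([HumanMessage("Hello")])
        assert payload.get("service_tier") is None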