groq: Add service tier option to ChatGroq (#31801)

- Allows users to select a [flex processing](https://console.groq.com/docs/flex-processing) service tier.
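
A minimal usage sketch of the new option (assuming a `GROQ_API_KEY` in the environment; the model name is simply the one the repo's tests use):

```python
from langchain_groq import ChatGroq

# Class-level default: every request from this instance asks for the flex tier.
chat = ChatGroq(model="llama-3.1-8b-instant", service_tier="flex")

response = chat.invoke("Why is the sky blue?")
# The tier used for the request is surfaced in the response metadata.
print(response.response_metadata.get("service_tier"))  # -> "flex"

# Per-request override: the kwarg takes precedence over the class-level setting.
response = chat.invoke("Why is the sky blue?", service_tier="on_demand")
print(response.response_metadata.get("service_tier"))  # -> "on_demand"
```
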
Mason Daugherty 2025-07-03 10:11:18 -04:00 committed by GitHub
parent 10ec5c8f02
commit 911b0b69ea
3 changed files with 137 additions and 3 deletions


@@ -375,6 +375,21 @@ class ChatGroq(BaseChatModel):
"""Number of chat completions to generate for each prompt."""
max_tokens: Optional[int] = None
"""Maximum number of tokens to generate."""
service_tier: Literal["on_demand", "flex", "auto"] = Field(default="on_demand")
"""Optional parameter that you can include to specify the service tier you'd like to
use for requests.
- ``'on_demand'``: Default.
- ``'flex'``: On-demand processing when capacity is available, with rapid timeouts
if resources are constrained. Provides balance between performance and reliability
for workloads that don't require guaranteed processing.
- ``'auto'``: Uses on-demand rate limits, then falls back to ``'flex'`` if those
limits are exceeded
See the `Groq documentation
<https://console.groq.com/docs/flex-processing>`__ for more details and a list of
service tiers and descriptions.
"""
default_headers: Union[Mapping[str, str], None] = None
default_query: Union[Mapping[str, object], None] = None
# Configure a custom httpx client. See the
@@ -534,7 +549,7 @@ class ChatGroq(BaseChatModel):
**kwargs,
}
response = self.client.create(messages=message_dicts, **params)
return self._create_chat_result(response)
return self._create_chat_result(response, params)
async def _agenerate(
self,
@@ -555,7 +570,7 @@ class ChatGroq(BaseChatModel):
**kwargs,
}
response = await self.async_client.create(messages=message_dicts, **params)
return self._create_chat_result(response)
return self._create_chat_result(response, params)
def _stream(
self,
@@ -582,6 +597,8 @@
generation_info["model_name"] = self.model_name
if system_fingerprint := chunk.get("system_fingerprint"):
generation_info["system_fingerprint"] = system_fingerprint
service_tier = params.get("service_tier") or self.service_tier
generation_info["service_tier"] = service_tier
logprobs = choice.get("logprobs")
if logprobs:
generation_info["logprobs"] = logprobs
@@ -623,6 +640,8 @@
generation_info["model_name"] = self.model_name
if system_fingerprint := chunk.get("system_fingerprint"):
generation_info["system_fingerprint"] = system_fingerprint
service_tier = params.get("service_tier") or self.service_tier
generation_info["service_tier"] = service_tier
logprobs = choice.get("logprobs")
if logprobs:
generation_info["logprobs"] = logprobs
@@ -653,13 +672,16 @@
"stop": self.stop,
"reasoning_format": self.reasoning_format,
"reasoning_effort": self.reasoning_effort,
"service_tier": self.service_tier,
**self.model_kwargs,
}
if self.max_tokens is not None:
params["max_tokens"] = self.max_tokens
return params
def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
def _create_chat_result(
self, response: Union[dict, BaseModel], params: dict
) -> ChatResult:
generations = []
if not isinstance(response, dict):
response = response.model_dump()
@@ -689,6 +711,7 @@
"model_name": self.model_name,
"system_fingerprint": response.get("system_fingerprint", ""),
}
llm_output["service_tier"] = params.get("service_tier") or self.service_tier
return ChatResult(generations=generations, llm_output=llm_output)
def _create_message_dicts(
@@ -719,6 +742,8 @@
combined = {"token_usage": overall_token_usage, "model_name": self.model_name}
if system_fingerprint:
combined["system_fingerprint"] = system_fingerprint
if self.service_tier:
combined["service_tier"] = self.service_tier
return combined
@deprecated(
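
For reference, the precedence applied above when recording the tier (a request-level `service_tier` kwarg wins, otherwise the class default is used), written out as a standalone sketch; `resolve_service_tier` is a hypothetical helper name, not part of the change:

```python
# Mirrors the `params.get("service_tier") or self.service_tier` pattern in the diff above.
def resolve_service_tier(request_params: dict, class_default: str = "on_demand") -> str:
    return request_params.get("service_tier") or class_default

assert resolve_service_tier({"service_tier": "flex"}) == "flex"     # per-request value wins
assert resolve_service_tier({}) == "on_demand"                      # falls back to the class default
assert resolve_service_tier({"service_tier": None}) == "on_demand"  # explicit None also falls back
```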


@@ -4,6 +4,7 @@ import json
from typing import Any, Optional, cast
import pytest
from groq import BadRequestError
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
@@ -467,6 +468,113 @@ def test_json_mode_structured_output() -> None:
assert len(result.punchline) != 0
def test_setting_service_tier_class() -> None:
"""Test setting service tier defined at ChatGroq level."""
message = HumanMessage(content="Welcome to the Groqetship")
# Initialization
chat = ChatGroq(model=MODEL_NAME, service_tier="auto")
assert chat.service_tier == "auto"
response = chat.invoke([message])
assert isinstance(response, BaseMessage)
assert isinstance(response.content, str)
assert response.response_metadata.get("service_tier") == "auto"
chat = ChatGroq(model=MODEL_NAME, service_tier="flex")
assert chat.service_tier == "flex"
response = chat.invoke([message])
assert response.response_metadata.get("service_tier") == "flex"
chat = ChatGroq(model=MODEL_NAME, service_tier="on_demand")
assert chat.service_tier == "on_demand"
response = chat.invoke([message])
assert response.response_metadata.get("service_tier") == "on_demand"
chat = ChatGroq(model=MODEL_NAME)
assert chat.service_tier == "on_demand"
response = chat.invoke([message])
assert response.response_metadata.get("service_tier") == "on_demand"
with pytest.raises(ValueError):
ChatGroq(model=MODEL_NAME, service_tier=None) # type: ignore
with pytest.raises(ValueError):
ChatGroq(model=MODEL_NAME, service_tier="invalid") # type: ignore
def test_setting_service_tier_request() -> None:
"""Test setting service tier defined at request level."""
message = HumanMessage(content="Welcome to the Groqetship")
chat = ChatGroq(model=MODEL_NAME)
response = chat.invoke(
[message],
service_tier="auto",
)
assert isinstance(response, BaseMessage)
assert isinstance(response.content, str)
assert response.response_metadata.get("service_tier") == "auto"
response = chat.invoke(
[message],
service_tier="flex",
)
assert response.response_metadata.get("service_tier") == "flex"
response = chat.invoke(
[message],
service_tier="on_demand",
)
assert response.response_metadata.get("service_tier") == "on_demand"
assert chat.service_tier == "on_demand"
response = chat.invoke(
[message],
)
assert response.response_metadata.get("service_tier") == "on_demand"
# If an `invoke` call is made with no service tier, we fall back to the class level
# setting
chat = ChatGroq(model=MODEL_NAME, service_tier="auto")
response = chat.invoke(
[message],
)
assert response.response_metadata.get("service_tier") == "auto"
response = chat.invoke(
[message],
service_tier="on_demand",
)
assert response.response_metadata.get("service_tier") == "on_demand"
with pytest.raises(BadRequestError):
response = chat.invoke(
[message],
service_tier="invalid",
)
response = chat.invoke(
[message],
service_tier=None,
)
assert response.response_metadata.get("service_tier") == "auto"
def test_setting_service_tier_streaming() -> None:
"""Test service tier settings for streaming calls."""
chat = ChatGroq(model=MODEL_NAME, service_tier="flex")
chunks = list(chat.stream("Why is the sky blue?", service_tier="auto"))
assert chunks[-1].response_metadata.get("service_tier") == "auto"
async def test_setting_service_tier_request_async() -> None:
"""Test async setting of service tier at the request level."""
chat = ChatGroq(model=MODEL_NAME, service_tier="flex")
response = await chat.ainvoke("Hello!", service_tier="on_demand")
assert response.response_metadata.get("service_tier") == "on_demand"
# Groq does not currently support N > 1
# @pytest.mark.scheduled
# def test_chat_multiple_completions() -> None:


@@ -19,6 +19,7 @@
'model_name': 'llama-3.1-8b-instant',
'n': 1,
'request_timeout': 60.0,
'service_tier': 'on_demand',
'stop': list([
]),
'temperature': 1e-08,
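
The snapshot change above reflects that the new field, with its `'on_demand'` default, now appears when a `ChatGroq` instance is serialized. A rough sketch of checking that, assuming the `dumpd` helper from `langchain_core.load` that the standard serialization tests rely on:

```python
from langchain_core.load import dumpd
from langchain_groq import ChatGroq

# A placeholder key is enough here; serialization makes no network call.
chat = ChatGroq(model="llama-3.1-8b-instant", api_key="dummy-key")
serialized = dumpd(chat)

# Per the updated snapshot, the serialized kwargs now carry the default tier.
print(serialized["kwargs"].get("service_tier"))  # expected: "on_demand"
```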