From 911b0b69ea7a07a6269fbbed7f7b6e73c2310aea Mon Sep 17 00:00:00 2001
From: Mason Daugherty
Date: Thu, 3 Jul 2025 10:11:18 -0400
Subject: [PATCH] groq: Add service tier option to ChatGroq (#31801)

- Allows users to select a [flex processing](https://console.groq.com/docs/flex-processing) service tier
---
 .../groq/langchain_groq/chat_models.py     |  31 ++++-
 .../integration_tests/test_chat_models.py  | 108 ++++++++++++++++++
 .../__snapshots__/test_standard.ambr       |   1 +
 3 files changed, 137 insertions(+), 3 deletions(-)

diff --git a/libs/partners/groq/langchain_groq/chat_models.py b/libs/partners/groq/langchain_groq/chat_models.py
index e285e613d84..85b200e6ace 100644
--- a/libs/partners/groq/langchain_groq/chat_models.py
+++ b/libs/partners/groq/langchain_groq/chat_models.py
@@ -375,6 +375,21 @@ class ChatGroq(BaseChatModel):
     """Number of chat completions to generate for each prompt."""
     max_tokens: Optional[int] = None
     """Maximum number of tokens to generate."""
+    service_tier: Literal["on_demand", "flex", "auto"] = Field(default="on_demand")
+    """Optional parameter that you can include to specify the service tier you'd like to
+    use for requests.
+
+    - ``'on_demand'``: Default.
+    - ``'flex'``: On-demand processing when capacity is available, with rapid timeouts
+      if resources are constrained. Provides balance between performance and reliability
+      for workloads that don't require guaranteed processing.
+    - ``'auto'``: Uses on-demand rate limits, then falls back to ``'flex'`` if those
+      limits are exceeded
+
+    See the `Groq documentation
+    <https://console.groq.com/docs/flex-processing>`__ for more details and a list of
+    service tiers and descriptions.
+    """
     default_headers: Union[Mapping[str, str], None] = None
     default_query: Union[Mapping[str, object], None] = None
     # Configure a custom httpx client.
@@ -534,7 +549,7 @@ class ChatGroq(BaseChatModel):
             **kwargs,
         }
         response = self.client.create(messages=message_dicts, **params)
-        return self._create_chat_result(response)
+        return self._create_chat_result(response, params)
 
     async def _agenerate(
         self,
@@ -555,7 +570,7 @@ class ChatGroq(BaseChatModel):
             **kwargs,
         }
         response = await self.async_client.create(messages=message_dicts, **params)
-        return self._create_chat_result(response)
+        return self._create_chat_result(response, params)
 
     def _stream(
         self,
@@ -582,6 +597,8 @@ class ChatGroq(BaseChatModel):
                 generation_info["model_name"] = self.model_name
                 if system_fingerprint := chunk.get("system_fingerprint"):
                     generation_info["system_fingerprint"] = system_fingerprint
+                service_tier = params.get("service_tier") or self.service_tier
+                generation_info["service_tier"] = service_tier
             logprobs = choice.get("logprobs")
             if logprobs:
                 generation_info["logprobs"] = logprobs
@@ -623,6 +640,8 @@ class ChatGroq(BaseChatModel):
                 generation_info["model_name"] = self.model_name
                 if system_fingerprint := chunk.get("system_fingerprint"):
                     generation_info["system_fingerprint"] = system_fingerprint
+                service_tier = params.get("service_tier") or self.service_tier
+                generation_info["service_tier"] = service_tier
             logprobs = choice.get("logprobs")
             if logprobs:
                 generation_info["logprobs"] = logprobs
@@ -653,13 +672,16 @@ class ChatGroq(BaseChatModel):
             "stop": self.stop,
             "reasoning_format": self.reasoning_format,
             "reasoning_effort": self.reasoning_effort,
+            "service_tier": self.service_tier,
             **self.model_kwargs,
         }
         if self.max_tokens is not None:
             params["max_tokens"] = self.max_tokens
         return params
 
-    def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
+    def _create_chat_result(
+        self, response: Union[dict, BaseModel], params: dict
+    ) -> ChatResult:
         generations = []
         if not isinstance(response, dict):
             response = response.model_dump()
@@ -689,6 +711,7 @@ class ChatGroq(BaseChatModel):
             "model_name": self.model_name,
             "system_fingerprint": response.get("system_fingerprint", ""),
         }
+        llm_output["service_tier"] = params.get("service_tier") or self.service_tier
         return ChatResult(generations=generations, llm_output=llm_output)
 
     def _create_message_dicts(
@@ -719,6 +742,8 @@ class ChatGroq(BaseChatModel):
         combined = {"token_usage": overall_token_usage, "model_name": self.model_name}
         if system_fingerprint:
             combined["system_fingerprint"] = system_fingerprint
+        if self.service_tier:
+            combined["service_tier"] = self.service_tier
         return combined
 
     @deprecated(
diff --git a/libs/partners/groq/tests/integration_tests/test_chat_models.py b/libs/partners/groq/tests/integration_tests/test_chat_models.py
index f344bd4e334..5ef2c3045d9 100644
--- a/libs/partners/groq/tests/integration_tests/test_chat_models.py
+++ b/libs/partners/groq/tests/integration_tests/test_chat_models.py
@@ -4,6 +4,7 @@ import json
 from typing import Any, Optional, cast
 
 import pytest
+from groq import BadRequestError
 from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
@@ -467,6 +468,113 @@ def test_json_mode_structured_output() -> None:
     assert len(result.punchline) != 0
 
 
+def test_setting_service_tier_class() -> None:
+    """Test setting service tier defined at ChatGroq level."""
+    message = HumanMessage(content="Welcome to the Groqetship")
+
+    # Initialization
+    chat = ChatGroq(model=MODEL_NAME, service_tier="auto")
+    assert chat.service_tier == "auto"
+    response = chat.invoke([message])
+    assert isinstance(response, BaseMessage)
+    assert isinstance(response.content, str)
+    assert response.response_metadata.get("service_tier") == "auto"
+
+    chat = ChatGroq(model=MODEL_NAME, service_tier="flex")
+    assert chat.service_tier == "flex"
+    response = chat.invoke([message])
+    assert response.response_metadata.get("service_tier") == "flex"
+
+    chat = ChatGroq(model=MODEL_NAME, service_tier="on_demand")
+    assert chat.service_tier == "on_demand"
+    response = chat.invoke([message])
+    assert response.response_metadata.get("service_tier") == "on_demand"
+
+    chat = ChatGroq(model=MODEL_NAME)
+    assert chat.service_tier == "on_demand"
+    response = chat.invoke([message])
+    assert response.response_metadata.get("service_tier") == "on_demand"
+
+    with pytest.raises(ValueError):
+        ChatGroq(model=MODEL_NAME, service_tier=None)  # type: ignore
+    with pytest.raises(ValueError):
+        ChatGroq(model=MODEL_NAME, service_tier="invalid")  # type: ignore
+
+
+def test_setting_service_tier_request() -> None:
+    """Test setting service tier defined at request level."""
+    message = HumanMessage(content="Welcome to the Groqetship")
+    chat = ChatGroq(model=MODEL_NAME)
+
+    response = chat.invoke(
+        [message],
+        service_tier="auto",
+    )
+    assert isinstance(response, BaseMessage)
+    assert isinstance(response.content, str)
+    assert response.response_metadata.get("service_tier") == "auto"
+
+    response = chat.invoke(
+        [message],
+        service_tier="flex",
+    )
+    assert response.response_metadata.get("service_tier") == "flex"
+
+    response = chat.invoke(
+        [message],
+        service_tier="on_demand",
+    )
+    assert response.response_metadata.get("service_tier") == "on_demand"
+
+    assert chat.service_tier == "on_demand"
+    response = chat.invoke(
+        [message],
+    )
+    assert response.response_metadata.get("service_tier") == "on_demand"
+
+    # If an `invoke` call is made with no service tier, we fall back to the class level
+    # setting
+    chat = ChatGroq(model=MODEL_NAME, service_tier="auto")
+    response = chat.invoke(
+        [message],
+    )
+    assert response.response_metadata.get("service_tier") == "auto"
+
+    response = chat.invoke(
+        [message],
+        service_tier="on_demand",
+    )
+    assert response.response_metadata.get("service_tier") == "on_demand"
+
+    with pytest.raises(BadRequestError):
+        response = chat.invoke(
+            [message],
+            service_tier="invalid",
+        )
+
+    response = chat.invoke(
+        [message],
+        service_tier=None,
+    )
+    assert response.response_metadata.get("service_tier") == "auto"
+
+
+def test_setting_service_tier_streaming() -> None:
+    """Test service tier settings for streaming calls."""
+    chat = ChatGroq(model=MODEL_NAME, service_tier="flex")
+    chunks = list(chat.stream("Why is the sky blue?", service_tier="auto"))
+
+    assert chunks[-1].response_metadata.get("service_tier") == "auto"
+
+
+async def test_setting_service_tier_request_async() -> None:
+    """Test async setting of service tier at the request level."""
+    chat = ChatGroq(model=MODEL_NAME, service_tier="flex")
+    response = await chat.ainvoke("Hello!", service_tier="on_demand")
+
+    assert response.response_metadata.get("service_tier") == "on_demand"
+
+
 # Groq does not currently support N > 1
 # @pytest.mark.scheduled
 # def test_chat_multiple_completions() -> None:
diff --git a/libs/partners/groq/tests/unit_tests/__snapshots__/test_standard.ambr b/libs/partners/groq/tests/unit_tests/__snapshots__/test_standard.ambr
index 7b8db708167..055fad288c1 100644
--- a/libs/partners/groq/tests/unit_tests/__snapshots__/test_standard.ambr
+++ b/libs/partners/groq/tests/unit_tests/__snapshots__/test_standard.ambr
@@ -19,6 +19,7 @@
         'model_name': 'llama-3.1-8b-instant',
         'n': 1,
         'request_timeout': 60.0,
+        'service_tier': 'on_demand',
         'stop': list([
         ]),
         'temperature': 1e-08,
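
Usage sketch (not part of the patch): the snippet below illustrates how the new `service_tier` option is intended to be used, both as a class-level default and as a per-request override, based on the docstring and the integration tests above. The model name is taken from the test snapshot and stands in for whatever Groq model you actually use.

    from langchain_groq import ChatGroq

    # Class-level default: requests use the "flex" tier unless overridden.
    llm = ChatGroq(model="llama-3.1-8b-instant", service_tier="flex")

    # Per-request override; omitting service_tier falls back to the class-level setting.
    response = llm.invoke("Hello!", service_tier="on_demand")
    print(response.response_metadata.get("service_tier"))  # "on_demand"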