Mirror of https://github.com/hwchase17/langchain.git, synced 2025-07-08 14:05:16 +00:00
groq: Add service tier option to ChatGroq (#31801)
- Allows users to select a [flex processing](https://console.groq.com/docs/flex-processing) service tier
parent 10ec5c8f02
commit 911b0b69ea
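In practice the option is set when constructing the model. A minimal sketch, assuming the langchain-groq package is installed and GROQ_API_KEY is set in the environment; the model name is only an example:

from langchain_groq import ChatGroq

# "flex" requests on-demand processing with quick timeouts when capacity is tight.
chat = ChatGroq(model="llama-3.1-8b-instant", service_tier="flex")
print(chat.service_tier)  # "flex"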
langchain_groq/chat_models.py:

@@ -375,6 +375,21 @@ class ChatGroq(BaseChatModel):
     """Number of chat completions to generate for each prompt."""
     max_tokens: Optional[int] = None
     """Maximum number of tokens to generate."""
+    service_tier: Literal["on_demand", "flex", "auto"] = Field(default="on_demand")
+    """Optional parameter that you can include to specify the service tier you'd like to
+    use for requests.
+
+    - ``'on_demand'``: Default.
+    - ``'flex'``: On-demand processing when capacity is available, with rapid timeouts
+      if resources are constrained. Provides balance between performance and reliability
+      for workloads that don't require guaranteed processing.
+    - ``'auto'``: Uses on-demand rate limits, then falls back to ``'flex'`` if those
+      limits are exceeded
+
+    See the `Groq documentation
+    <https://console.groq.com/docs/flex-processing>`__ for more details and a list of
+    service tiers and descriptions.
+    """
     default_headers: Union[Mapping[str, str], None] = None
     default_query: Union[Mapping[str, object], None] = None
     # Configure a custom httpx client. See the
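A usage sketch of the per-request override described by this field, mirroring the integration tests added later in this commit; the model name, prompt, and environment setup are illustrative assumptions, not part of the diff:

from langchain_core.messages import HumanMessage
from langchain_groq import ChatGroq

# Assumes GROQ_API_KEY is set in the environment; model name is illustrative.
chat = ChatGroq(model="llama-3.1-8b-instant")  # class-level default tier: "on_demand"
message = HumanMessage(content="Welcome to the Groqetship")

# A tier passed at request time overrides the class-level setting for that call.
response = chat.invoke([message], service_tier="auto")
print(response.response_metadata.get("service_tier"))  # expected: "auto"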
@@ -534,7 +549,7 @@ class ChatGroq(BaseChatModel):
             **kwargs,
         }
         response = self.client.create(messages=message_dicts, **params)
-        return self._create_chat_result(response)
+        return self._create_chat_result(response, params)
 
     async def _agenerate(
         self,

@@ -555,7 +570,7 @@ class ChatGroq(BaseChatModel):
             **kwargs,
         }
         response = await self.async_client.create(messages=message_dicts, **params)
-        return self._create_chat_result(response)
+        return self._create_chat_result(response, params)
 
     def _stream(
         self,
@@ -582,6 +597,8 @@ class ChatGroq(BaseChatModel):
                 generation_info["model_name"] = self.model_name
                 if system_fingerprint := chunk.get("system_fingerprint"):
                     generation_info["system_fingerprint"] = system_fingerprint
+                service_tier = params.get("service_tier") or self.service_tier
+                generation_info["service_tier"] = service_tier
                 logprobs = choice.get("logprobs")
                 if logprobs:
                     generation_info["logprobs"] = logprobs

@@ -623,6 +640,8 @@ class ChatGroq(BaseChatModel):
                 generation_info["model_name"] = self.model_name
                 if system_fingerprint := chunk.get("system_fingerprint"):
                     generation_info["system_fingerprint"] = system_fingerprint
+                service_tier = params.get("service_tier") or self.service_tier
+                generation_info["service_tier"] = service_tier
                 logprobs = choice.get("logprobs")
                 if logprobs:
                     generation_info["logprobs"] = logprobs
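During streaming, the tier chosen for the request (or, failing that, the class-level tier) is attached to the generation info, so it shows up in chunk metadata. A sketch mirroring the streaming integration test below, with an illustrative model name and prompt:

from langchain_groq import ChatGroq

# Assumes GROQ_API_KEY is set; model name is illustrative.
chat = ChatGroq(model="llama-3.1-8b-instant", service_tier="flex")
chunks = list(chat.stream("Why is the sky blue?", service_tier="auto"))

# The request-level tier wins over the class-level "flex" setting.
print(chunks[-1].response_metadata.get("service_tier"))  # expected: "auto"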
@@ -653,13 +672,16 @@ class ChatGroq(BaseChatModel):
             "stop": self.stop,
             "reasoning_format": self.reasoning_format,
             "reasoning_effort": self.reasoning_effort,
+            "service_tier": self.service_tier,
             **self.model_kwargs,
         }
         if self.max_tokens is not None:
             params["max_tokens"] = self.max_tokens
         return params
 
-    def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
+    def _create_chat_result(
+        self, response: Union[dict, BaseModel], params: dict
+    ) -> ChatResult:
         generations = []
         if not isinstance(response, dict):
             response = response.model_dump()
@@ -689,6 +711,7 @@ class ChatGroq(BaseChatModel):
             "model_name": self.model_name,
             "system_fingerprint": response.get("system_fingerprint", ""),
         }
+        llm_output["service_tier"] = params.get("service_tier") or self.service_tier
         return ChatResult(generations=generations, llm_output=llm_output)
 
     def _create_message_dicts(
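Because both the streaming path and `_create_chat_result` use `params.get("service_tier") or self.service_tier`, a request-level value of `None` is falsy and falls back to the class-level tier. A sketch of that fallback, with illustrative names:

from langchain_core.messages import HumanMessage
from langchain_groq import ChatGroq

# Assumes GROQ_API_KEY is set; model name is illustrative.
chat = ChatGroq(model="llama-3.1-8b-instant", service_tier="auto")
message = HumanMessage(content="Hello!")

# Passing service_tier=None does not clear the tier; the class-level "auto" is used.
response = chat.invoke([message], service_tier=None)
print(response.response_metadata.get("service_tier"))  # expected: "auto"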
@@ -719,6 +742,8 @@ class ChatGroq(BaseChatModel):
         combined = {"token_usage": overall_token_usage, "model_name": self.model_name}
         if system_fingerprint:
             combined["system_fingerprint"] = system_fingerprint
+        if self.service_tier:
+            combined["service_tier"] = self.service_tier
         return combined
 
     @deprecated(
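For batched calls, `_combine_llm_outputs` reports the class-level tier on the combined `llm_output`. A sketch assuming the standard `BaseChatModel.generate` batching API, with illustrative names:

from langchain_core.messages import HumanMessage
from langchain_groq import ChatGroq

# Assumes GROQ_API_KEY is set; model name is illustrative.
chat = ChatGroq(model="llama-3.1-8b-instant", service_tier="flex")
message = HumanMessage(content="Hello!")

# generate() takes a batch (a list of message lists); the combined llm_output
# carries the class-level service tier.
result = chat.generate([[message], [message]])
print((result.llm_output or {}).get("service_tier"))  # expected: "flex"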
tests/integration_tests/test_chat_models.py:

@@ -4,6 +4,7 @@ import json
 from typing import Any, Optional, cast
 
 import pytest
+from groq import BadRequestError
 from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,

@@ -467,6 +468,113 @@ def test_json_mode_structured_output() -> None:
     assert len(result.punchline) != 0
 
 
+def test_setting_service_tier_class() -> None:
+    """Test setting service tier defined at ChatGroq level."""
+    message = HumanMessage(content="Welcome to the Groqetship")
+
+    # Initialization
+    chat = ChatGroq(model=MODEL_NAME, service_tier="auto")
+    assert chat.service_tier == "auto"
+    response = chat.invoke([message])
+    assert isinstance(response, BaseMessage)
+    assert isinstance(response.content, str)
+    assert response.response_metadata.get("service_tier") == "auto"
+
+    chat = ChatGroq(model=MODEL_NAME, service_tier="flex")
+    assert chat.service_tier == "flex"
+    response = chat.invoke([message])
+    assert response.response_metadata.get("service_tier") == "flex"
+
+    chat = ChatGroq(model=MODEL_NAME, service_tier="on_demand")
+    assert chat.service_tier == "on_demand"
+    response = chat.invoke([message])
+    assert response.response_metadata.get("service_tier") == "on_demand"
+
+    chat = ChatGroq(model=MODEL_NAME)
+    assert chat.service_tier == "on_demand"
+    response = chat.invoke([message])
+    assert response.response_metadata.get("service_tier") == "on_demand"
+
+    with pytest.raises(ValueError):
+        ChatGroq(model=MODEL_NAME, service_tier=None)  # type: ignore
+    with pytest.raises(ValueError):
+        ChatGroq(model=MODEL_NAME, service_tier="invalid")  # type: ignore
+
+
+def test_setting_service_tier_request() -> None:
+    """Test setting service tier defined at request level."""
+    message = HumanMessage(content="Welcome to the Groqetship")
+    chat = ChatGroq(model=MODEL_NAME)
+
+    response = chat.invoke(
+        [message],
+        service_tier="auto",
+    )
+    assert isinstance(response, BaseMessage)
+    assert isinstance(response.content, str)
+    assert response.response_metadata.get("service_tier") == "auto"
+
+    response = chat.invoke(
+        [message],
+        service_tier="flex",
+    )
+    assert response.response_metadata.get("service_tier") == "flex"
+
+    response = chat.invoke(
+        [message],
+        service_tier="on_demand",
+    )
+    assert response.response_metadata.get("service_tier") == "on_demand"
+
+    assert chat.service_tier == "on_demand"
+    response = chat.invoke(
+        [message],
+    )
+    assert response.response_metadata.get("service_tier") == "on_demand"
+
+    # If an `invoke` call is made with no service tier, we fall back to the class level
+    # setting
+    chat = ChatGroq(model=MODEL_NAME, service_tier="auto")
+    response = chat.invoke(
+        [message],
+    )
+    assert response.response_metadata.get("service_tier") == "auto"
+
+    response = chat.invoke(
+        [message],
+        service_tier="on_demand",
+    )
+    assert response.response_metadata.get("service_tier") == "on_demand"
+
+    with pytest.raises(BadRequestError):
+        response = chat.invoke(
+            [message],
+            service_tier="invalid",
+        )
+
+    response = chat.invoke(
+        [message],
+        service_tier=None,
+    )
+    assert response.response_metadata.get("service_tier") == "auto"
+
+
+def test_setting_service_tier_streaming() -> None:
+    """Test service tier settings for streaming calls."""
+    chat = ChatGroq(model=MODEL_NAME, service_tier="flex")
+    chunks = list(chat.stream("Why is the sky blue?", service_tier="auto"))
+
+    assert chunks[-1].response_metadata.get("service_tier") == "auto"
+
+
+async def test_setting_service_tier_request_async() -> None:
+    """Test async setting of service tier at the request level."""
+    chat = ChatGroq(model=MODEL_NAME, service_tier="flex")
+    response = await chat.ainvoke("Hello!", service_tier="on_demand")
+
+    assert response.response_metadata.get("service_tier") == "on_demand"
+
+
 # Groq does not currently support N > 1
 # @pytest.mark.scheduled
 # def test_chat_multiple_completions() -> None:
Unit-test serialization snapshot (__snapshots__):

@@ -19,6 +19,7 @@
       'model_name': 'llama-3.1-8b-instant',
       'n': 1,
       'request_timeout': 60.0,
+      'service_tier': 'on_demand',
       'stop': list([
       ]),
       'temperature': 1e-08,