Mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-08 22:15:08 +00:00)
groq: Add service tier option to ChatGroq (#31801)
- Allows users to select a [flex processing](https://console.groq.com/docs/flex-processing) service tier
This commit is contained in: parent 10ec5c8f02, commit 911b0b69ea
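
For reference, a minimal usage sketch of the new option (assumes `langchain-groq` is installed and `GROQ_API_KEY` is set; the model name is only illustrative):

```python
from langchain_groq import ChatGroq

# Class-level setting: every request from this instance asks for the flex tier.
llm = ChatGroq(model="llama-3.1-8b-instant", service_tier="flex")

response = llm.invoke("Why is the sky blue?")
# The tier used for the call is surfaced in the response metadata.
print(response.response_metadata.get("service_tier"))  # -> "flex"
```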
@@ -375,6 +375,21 @@ class ChatGroq(BaseChatModel):
     """Number of chat completions to generate for each prompt."""
     max_tokens: Optional[int] = None
     """Maximum number of tokens to generate."""
+    service_tier: Literal["on_demand", "flex", "auto"] = Field(default="on_demand")
+    """Optional parameter that you can include to specify the service tier you'd like to
+    use for requests.
+
+    - ``'on_demand'``: Default.
+    - ``'flex'``: On-demand processing when capacity is available, with rapid timeouts
+      if resources are constrained. Provides balance between performance and reliability
+      for workloads that don't require guaranteed processing.
+    - ``'auto'``: Uses on-demand rate limits, then falls back to ``'flex'`` if those
+      limits are exceeded
+
+    See the `Groq documentation
+    <https://console.groq.com/docs/flex-processing>`__ for more details and a list of
+    service tiers and descriptions.
+    """
     default_headers: Union[Mapping[str, str], None] = None
     default_query: Union[Mapping[str, object], None] = None
     # Configure a custom httpx client. See the
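
Since `service_tier` is a `Literal` field with an `"on_demand"` default, values outside the allowed set are rejected at construction time by pydantic validation (this is what the integration tests below exercise with `pytest.raises(ValueError)`). A small sketch of that behavior, with a deliberately invalid value:

```python
from langchain_groq import ChatGroq

llm = ChatGroq(model="llama-3.1-8b-instant")
assert llm.service_tier == "on_demand"  # default tier

try:
    # Anything outside {"on_demand", "flex", "auto"} fails field validation.
    ChatGroq(model="llama-3.1-8b-instant", service_tier="premium")  # type: ignore
except ValueError as err:
    print(f"rejected: {err}")
```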
@@ -534,7 +549,7 @@ class ChatGroq(BaseChatModel):
             **kwargs,
         }
         response = self.client.create(messages=message_dicts, **params)
-        return self._create_chat_result(response)
+        return self._create_chat_result(response, params)
 
     async def _agenerate(
         self,
@@ -555,7 +570,7 @@ class ChatGroq(BaseChatModel):
             **kwargs,
         }
         response = await self.async_client.create(messages=message_dicts, **params)
-        return self._create_chat_result(response)
+        return self._create_chat_result(response, params)
 
     def _stream(
         self,
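
Passing `params` (which already has per-call `**kwargs` merged in) into `_create_chat_result` is what lets the tier actually used for a call surface in the result metadata. A sketch of that round trip, assuming an API key and an illustrative model name:

```python
from langchain_groq import ChatGroq

llm = ChatGroq(model="llama-3.1-8b-instant", service_tier="on_demand")

# A request-level kwarg is merged into `params`, so this call reports "flex"
# even though the class default is "on_demand".
msg = llm.invoke("Hello!", service_tier="flex")
print(msg.response_metadata.get("service_tier"))  # -> "flex"
```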
@@ -582,6 +597,8 @@ class ChatGroq(BaseChatModel):
                 generation_info["model_name"] = self.model_name
                 if system_fingerprint := chunk.get("system_fingerprint"):
                     generation_info["system_fingerprint"] = system_fingerprint
+                service_tier = params.get("service_tier") or self.service_tier
+                generation_info["service_tier"] = service_tier
             logprobs = choice.get("logprobs")
             if logprobs:
                 generation_info["logprobs"] = logprobs
@@ -623,6 +640,8 @@ class ChatGroq(BaseChatModel):
                 generation_info["model_name"] = self.model_name
                 if system_fingerprint := chunk.get("system_fingerprint"):
                     generation_info["system_fingerprint"] = system_fingerprint
+                service_tier = params.get("service_tier") or self.service_tier
+                generation_info["service_tier"] = service_tier
             logprobs = choice.get("logprobs")
             if logprobs:
                 generation_info["logprobs"] = logprobs
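
The streaming paths record the same metadata, and per the streaming integration test below the tier can be read from the final chunk. A sketch (API key and model name assumed):

```python
from langchain_groq import ChatGroq

llm = ChatGroq(model="llama-3.1-8b-instant", service_tier="flex")

# The request-level tier again wins over the class default; it is reported in
# the response_metadata of the last streamed chunk.
chunks = list(llm.stream("Why is the sky blue?", service_tier="auto"))
print(chunks[-1].response_metadata.get("service_tier"))  # -> "auto"
```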
@@ -653,13 +672,16 @@ class ChatGroq(BaseChatModel):
             "stop": self.stop,
             "reasoning_format": self.reasoning_format,
             "reasoning_effort": self.reasoning_effort,
+            "service_tier": self.service_tier,
             **self.model_kwargs,
         }
         if self.max_tokens is not None:
             params["max_tokens"] = self.max_tokens
         return params
 
-    def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
+    def _create_chat_result(
+        self, response: Union[dict, BaseModel], params: dict
+    ) -> ChatResult:
         generations = []
         if not isinstance(response, dict):
             response = response.model_dump()
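
The `params.get("service_tier") or self.service_tier` fallback gives request-level values precedence over the class default and treats an explicit `None` as "not set", which is why a `service_tier=None` request falls back to the instance setting in the tests below. A tiny standalone sketch of that precedence rule (the helper name is hypothetical, not part of the library):

```python
def effective_service_tier(params: dict, class_default: str = "on_demand") -> str:
    # Hypothetical helper mirroring the fallback above: a request-level value
    # wins; a missing or None value falls back to the class-level setting.
    return params.get("service_tier") or class_default


assert effective_service_tier({"service_tier": "flex"}) == "flex"
assert effective_service_tier({"service_tier": None}) == "on_demand"
assert effective_service_tier({}) == "on_demand"
```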
@@ -689,6 +711,7 @@ class ChatGroq(BaseChatModel):
             "model_name": self.model_name,
             "system_fingerprint": response.get("system_fingerprint", ""),
         }
+        llm_output["service_tier"] = params.get("service_tier") or self.service_tier
         return ChatResult(generations=generations, llm_output=llm_output)
 
     def _create_message_dicts(
@@ -719,6 +742,8 @@ class ChatGroq(BaseChatModel):
         combined = {"token_usage": overall_token_usage, "model_name": self.model_name}
         if system_fingerprint:
             combined["system_fingerprint"] = system_fingerprint
+        if self.service_tier:
+            combined["service_tier"] = self.service_tier
         return combined
 
     @deprecated(
@@ -4,6 +4,7 @@ import json
 from typing import Any, Optional, cast
 
 import pytest
+from groq import BadRequestError
 from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
@@ -467,6 +468,113 @@ def test_json_mode_structured_output() -> None:
     assert len(result.punchline) != 0
 
 
+def test_setting_service_tier_class() -> None:
+    """Test setting service tier defined at ChatGroq level."""
+    message = HumanMessage(content="Welcome to the Groqetship")
+
+    # Initialization
+    chat = ChatGroq(model=MODEL_NAME, service_tier="auto")
+    assert chat.service_tier == "auto"
+    response = chat.invoke([message])
+    assert isinstance(response, BaseMessage)
+    assert isinstance(response.content, str)
+    assert response.response_metadata.get("service_tier") == "auto"
+
+    chat = ChatGroq(model=MODEL_NAME, service_tier="flex")
+    assert chat.service_tier == "flex"
+    response = chat.invoke([message])
+    assert response.response_metadata.get("service_tier") == "flex"
+
+    chat = ChatGroq(model=MODEL_NAME, service_tier="on_demand")
+    assert chat.service_tier == "on_demand"
+    response = chat.invoke([message])
+    assert response.response_metadata.get("service_tier") == "on_demand"
+
+    chat = ChatGroq(model=MODEL_NAME)
+    assert chat.service_tier == "on_demand"
+    response = chat.invoke([message])
+    assert response.response_metadata.get("service_tier") == "on_demand"
+
+    with pytest.raises(ValueError):
+        ChatGroq(model=MODEL_NAME, service_tier=None)  # type: ignore
+    with pytest.raises(ValueError):
+        ChatGroq(model=MODEL_NAME, service_tier="invalid")  # type: ignore
+
+
+def test_setting_service_tier_request() -> None:
+    """Test setting service tier defined at request level."""
+    message = HumanMessage(content="Welcome to the Groqetship")
+    chat = ChatGroq(model=MODEL_NAME)
+
+    response = chat.invoke(
+        [message],
+        service_tier="auto",
+    )
+    assert isinstance(response, BaseMessage)
+    assert isinstance(response.content, str)
+    assert response.response_metadata.get("service_tier") == "auto"
+
+    response = chat.invoke(
+        [message],
+        service_tier="flex",
+    )
+    assert response.response_metadata.get("service_tier") == "flex"
+
+    response = chat.invoke(
+        [message],
+        service_tier="on_demand",
+    )
+    assert response.response_metadata.get("service_tier") == "on_demand"
+
+    assert chat.service_tier == "on_demand"
+    response = chat.invoke(
+        [message],
+    )
+    assert response.response_metadata.get("service_tier") == "on_demand"
+
+    # If an `invoke` call is made with no service tier, we fall back to the class level
+    # setting
+    chat = ChatGroq(model=MODEL_NAME, service_tier="auto")
+    response = chat.invoke(
+        [message],
+    )
+    assert response.response_metadata.get("service_tier") == "auto"
+
+    response = chat.invoke(
+        [message],
+        service_tier="on_demand",
+    )
+    assert response.response_metadata.get("service_tier") == "on_demand"
+
+    with pytest.raises(BadRequestError):
+        response = chat.invoke(
+            [message],
+            service_tier="invalid",
+        )
+
+    response = chat.invoke(
+        [message],
+        service_tier=None,
+    )
+    assert response.response_metadata.get("service_tier") == "auto"
+
+
+def test_setting_service_tier_streaming() -> None:
+    """Test service tier settings for streaming calls."""
+    chat = ChatGroq(model=MODEL_NAME, service_tier="flex")
+    chunks = list(chat.stream("Why is the sky blue?", service_tier="auto"))
+
+    assert chunks[-1].response_metadata.get("service_tier") == "auto"
+
+
+async def test_setting_service_tier_request_async() -> None:
+    """Test async setting of service tier at the request level."""
+    chat = ChatGroq(model=MODEL_NAME, service_tier="flex")
+    response = await chat.ainvoke("Hello!", service_tier="on_demand")
+
+    assert response.response_metadata.get("service_tier") == "on_demand"
+
+
 # Groq does not currently support N > 1
 # @pytest.mark.scheduled
 # def test_chat_multiple_completions() -> None:
@@ -19,6 +19,7 @@
       'model_name': 'llama-3.1-8b-instant',
       'n': 1,
       'request_timeout': 60.0,
+      'service_tier': 'on_demand',
       'stop': list([
       ]),
       'temperature': 1e-08,