groq[patch]: warn if model is not specified (#30161)

Groq is retiring `mixtral-8x7b-32768`, currently the default model for
ChatGroq, on March 20, 2025. Here we emit a warning if the model is not
specified explicitly.

Version 0.3.0, which removes the default altogether, will be released
ahead of March 20.
ccurme 2025-03-07 15:21:13 -05:00 (committed by GitHub)
parent 3444e587ee
commit 74e7772a5f
5 changed files with 88 additions and 35 deletions
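
For users, the migration is just pinning a model at construction time. A minimal sketch (the model name here is illustrative; any currently supported Groq model works):

```python
from langchain_groq import ChatGroq

# Explicit model: no warning now, and no breakage when the default
# is removed in langchain-groq 0.3.0.
llm = ChatGroq(model="llama-3.3-70b-versatile", temperature=0.0)

# Relying on the implicit default (mixtral-8x7b-32768) emits a one-time
# UserWarning after this change, and the model itself is retired on
# March 20, 2025.
legacy = ChatGroq()
```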


@@ -88,6 +88,8 @@ from typing_extensions import Self
from langchain_groq.version import __version__

+WARNED_DEFAULT_MODEL = False
+

class ChatGroq(BaseChatModel):
    """`Groq` Chat large language models API.
@@ -109,7 +111,7 @@ class ChatGroq(BaseChatModel):
    Key init args — completion params:
        model: str
-            Name of Groq model to use. E.g. "mixtral-8x7b-32768".
+            Name of Groq model to use. E.g. "llama-3.1-8b-instant".
        temperature: float
            Sampling temperature. Ranges from 0.0 to 1.0.
        max_tokens: Optional[int]
@@ -140,7 +142,7 @@ class ChatGroq(BaseChatModel):
            from langchain_groq import ChatGroq

            llm = ChatGroq(
-                model="mixtral-8x7b-32768",
+                model="llama-3.1-8b-instant",
                temperature=0.0,
                max_retries=2,
                # other params...
@@ -164,7 +166,7 @@ class ChatGroq(BaseChatModel):
            response_metadata={'token_usage': {'completion_tokens': 38,
            'prompt_tokens': 28, 'total_tokens': 66, 'completion_time':
            0.057975474, 'prompt_time': 0.005366091, 'queue_time': None,
-            'total_time': 0.063341565}, 'model_name': 'mixtral-8x7b-32768',
+            'total_time': 0.063341565}, 'model_name': 'llama-3.1-8b-instant',
            'system_fingerprint': 'fp_c5f20b5bb1', 'finish_reason': 'stop',
            'logprobs': None}, id='run-ecc71d70-e10c-4b69-8b8c-b8027d95d4b8-0')
@@ -222,7 +224,7 @@ class ChatGroq(BaseChatModel):
            response_metadata={'token_usage': {'completion_tokens': 53,
            'prompt_tokens': 28, 'total_tokens': 81, 'completion_time':
            0.083623752, 'prompt_time': 0.007365126, 'queue_time': None,
-            'total_time': 0.090988878}, 'model_name': 'mixtral-8x7b-32768',
+            'total_time': 0.090988878}, 'model_name': 'llama-3.1-8b-instant',
            'system_fingerprint': 'fp_c5f20b5bb1', 'finish_reason': 'stop',
            'logprobs': None}, id='run-897f3391-1bea-42e2-82e0-686e2367bcf8-0')
@@ -295,7 +297,7 @@ class ChatGroq(BaseChatModel):
            'prompt_time': 0.007518279,
            'queue_time': None,
            'total_time': 0.11947467},
-            'model_name': 'mixtral-8x7b-32768',
+            'model_name': 'llama-3.1-8b-instant',
            'system_fingerprint': 'fp_c5f20b5bb1',
            'finish_reason': 'stop',
            'logprobs': None}
@@ -351,6 +353,27 @@ class ChatGroq(BaseChatModel):
        populate_by_name=True,
    )

+    @model_validator(mode="before")
+    @classmethod
+    def warn_default_model(cls, values: Dict[str, Any]) -> Any:
+        """Warning anticipating removal of default model."""
+        # TODO(ccurme): remove this warning in 0.3.0 when default model is removed
+        global WARNED_DEFAULT_MODEL
+        if (
+            "model" not in values
+            and "model_name" not in values
+            and not WARNED_DEFAULT_MODEL
+        ):
+            warnings.warn(
+                "Groq is retiring the default model for ChatGroq, mixtral-8x7b-32768, "
+                "on March 20, 2025. Requests with the default model will start failing "
+                "on that date. Version 0.3.0 of langchain-groq will remove the "
+                "default. Please specify `model` explicitly, e.g., "
+                "`model='mistral-saba-24b'` or `model='llama-3.3-70b-versatile'`.",
+            )
+            WARNED_DEFAULT_MODEL = True
+        return values
+
    @model_validator(mode="before")
    @classmethod
    def build_extra(cls, values: Dict[str, Any]) -> Any:
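
A note on the mechanism: the module-level `WARNED_DEFAULT_MODEL` flag makes the warning fire at most once per process, however many model-less `ChatGroq` instances are created. A standalone sketch of the same warn-once pattern, using a hypothetical `Client` class rather than anything from langchain_groq:

```python
import warnings

_WARNED = False  # module-level, so shared by every instance in the process


class Client:
    def __init__(self, model: str | None = None) -> None:
        global _WARNED
        if model is None and not _WARNED:
            warnings.warn("No model specified; the default is being retired.")
            _WARNED = True  # later model-less constructions stay silent
        self.model = model or "legacy-default"


Client()  # warns once
Client()  # silent: the flag is already set
```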


@@ -21,6 +21,8 @@ from tests.unit_tests.fake.callbacks import (
    FakeCallbackHandlerWithChatStart,
)

+MODEL_NAME = "llama-3.3-70b-versatile"
+
#
# Smoke test Runnable interface
@@ -28,7 +30,8 @@ from tests.unit_tests.fake.callbacks import (
@pytest.mark.scheduled
def test_invoke() -> None:
    """Test Chat wrapper."""
-    chat = ChatGroq(  # type: ignore[call-arg]
+    chat = ChatGroq(
+        model=MODEL_NAME,
        temperature=0.7,
        base_url=None,
        groq_proxy=None,
@@ -49,7 +52,7 @@ def test_invoke() -> None:
@pytest.mark.scheduled
async def test_ainvoke() -> None:
    """Test ainvoke tokens from ChatGroq."""
-    chat = ChatGroq(max_tokens=10)  # type: ignore[call-arg]
+    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)
    result = await chat.ainvoke("Welcome to the Groqetship!", config={"tags": ["foo"]})

    assert isinstance(result, BaseMessage)
@@ -59,7 +62,7 @@ async def test_ainvoke() -> None:
@pytest.mark.scheduled
def test_batch() -> None:
    """Test batch tokens from ChatGroq."""
-    chat = ChatGroq(max_tokens=10)  # type: ignore[call-arg]
+    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)
    result = chat.batch(["Hello!", "Welcome to the Groqetship!"])

    for token in result:
@@ -70,7 +73,7 @@ def test_batch() -> None:
@pytest.mark.scheduled
async def test_abatch() -> None:
    """Test abatch tokens from ChatGroq."""
-    chat = ChatGroq(max_tokens=10)  # type: ignore[call-arg]
+    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)
    result = await chat.abatch(["Hello!", "Welcome to the Groqetship!"])

    for token in result:
@@ -81,7 +84,7 @@ async def test_abatch() -> None:
@pytest.mark.scheduled
async def test_stream() -> None:
    """Test streaming tokens from Groq."""
-    chat = ChatGroq(max_tokens=10)  # type: ignore[call-arg]
+    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)

    for token in chat.stream("Welcome to the Groqetship!"):
        assert isinstance(token, BaseMessageChunk)
@@ -91,7 +94,7 @@ async def test_stream() -> None:
@pytest.mark.scheduled
async def test_astream() -> None:
    """Test streaming tokens from Groq."""
-    chat = ChatGroq(max_tokens=10)  # type: ignore[call-arg]
+    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)

    full: Optional[BaseMessageChunk] = None
    chunks_with_token_counts = 0
@@ -124,7 +127,7 @@ async def test_astream() -> None:
def test_generate() -> None:
    """Test sync generate."""
    n = 1
-    chat = ChatGroq(max_tokens=10)  # type: ignore[call-arg]
+    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)
    message = HumanMessage(content="Hello", n=1)
    response = chat.generate([[message], [message]])
    assert isinstance(response, LLMResult)
@@ -143,7 +146,7 @@ def test_generate() -> None:
async def test_agenerate() -> None:
    """Test async generation."""
    n = 1
-    chat = ChatGroq(max_tokens=10, n=1)  # type: ignore[call-arg]
+    chat = ChatGroq(model=MODEL_NAME, max_tokens=10, n=1)
    message = HumanMessage(content="Hello")
    response = await chat.agenerate([[message], [message]])
    assert isinstance(response, LLMResult)
@@ -165,7 +168,8 @@ async def test_agenerate() -> None:
def test_invoke_streaming() -> None:
    """Test that streaming correctly invokes on_llm_new_token callback."""
    callback_handler = FakeCallbackHandler()
-    chat = ChatGroq(  # type: ignore[call-arg]
+    chat = ChatGroq(
+        model=MODEL_NAME,
        max_tokens=2,
        streaming=True,
        temperature=0,
@@ -181,7 +185,8 @@ def test_invoke_streaming() -> None:
async def test_agenerate_streaming() -> None:
    """Test that streaming correctly invokes on_llm_new_token callback."""
    callback_handler = FakeCallbackHandlerWithChatStart()
-    chat = ChatGroq(  # type: ignore[call-arg]
+    chat = ChatGroq(
+        model=MODEL_NAME,
        max_tokens=10,
        streaming=True,
        temperature=0,
@@ -220,7 +225,8 @@ def test_streaming_generation_info() -> None:
            self.saved_things["generation"] = args[0]

    callback = _FakeCallback()
-    chat = ChatGroq(  # type: ignore[call-arg]
+    chat = ChatGroq(
+        model=MODEL_NAME,
        max_tokens=2,
        temperature=0,
        callbacks=[callback],
@@ -234,7 +240,7 @@ def test_streaming_generation_info() -> None:
def test_system_message() -> None:
    """Test ChatGroq wrapper with system message."""
-    chat = ChatGroq(max_tokens=10)  # type: ignore[call-arg]
+    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)
    system_message = SystemMessage(content="You are to chat with the user.")
    human_message = HumanMessage(content="Hello")
    response = chat.invoke([system_message, human_message])
@@ -242,10 +248,9 @@ def test_system_message() -> None:
    assert isinstance(response.content, str)


@pytest.mark.xfail(reason="Groq tool_choice doesn't currently force a tool call")
def test_tool_choice() -> None:
    """Test that tool choice is respected."""
-    llm = ChatGroq()  # type: ignore[call-arg]
+    llm = ChatGroq(model=MODEL_NAME)

    class MyTool(BaseModel):
        name: str
@@ -273,10 +278,9 @@ def test_tool_choice() -> None:
    assert tool_call["args"] == {"name": "Erick", "age": 27}


@pytest.mark.xfail(reason="Groq tool_choice doesn't currently force a tool call")
def test_tool_choice_bool() -> None:
    """Test that tool choice is respected just passing in True."""
-    llm = ChatGroq()  # type: ignore[call-arg]
+    llm = ChatGroq(model=MODEL_NAME)

    class MyTool(BaseModel):
        name: str
@@ -301,7 +305,7 @@ def test_tool_choice_bool() -> None:
@pytest.mark.xfail(reason="Groq tool_choice doesn't currently force a tool call")
def test_streaming_tool_call() -> None:
    """Test that tool choice is respected."""
-    llm = ChatGroq()  # type: ignore[call-arg]
+    llm = ChatGroq(model=MODEL_NAME)

    class MyTool(BaseModel):
        name: str
@@ -339,7 +343,7 @@ def test_streaming_tool_call() -> None:
@pytest.mark.xfail(reason="Groq tool_choice doesn't currently force a tool call")
async def test_astreaming_tool_call() -> None:
    """Test that tool choice is respected."""
-    llm = ChatGroq()  # type: ignore[call-arg]
+    llm = ChatGroq(model=MODEL_NAME)

    class MyTool(BaseModel):
        name: str
@@ -384,7 +388,7 @@ def test_json_mode_structured_output() -> None:
        setup: str = Field(description="question to set up a joke")
        punchline: str = Field(description="answer to resolve the joke")

-    chat = ChatGroq().with_structured_output(Joke, method="json_mode")  # type: ignore[call-arg]
+    chat = ChatGroq(model=MODEL_NAME).with_structured_output(Joke, method="json_mode")
    result = chat.invoke(
        "Tell me a joke about cats, respond in JSON with `setup` and `punchline` keys"
    )


@@ -16,7 +16,7 @@
    }),
    'max_retries': 2,
    'max_tokens': 100,
-    'model_name': 'mixtral-8x7b-32768',
+    'model_name': 'llama-3.1-8b-instant',
    'n': 1,
    'request_timeout': 60.0,
    'stop': list([


@@ -2,6 +2,7 @@

import json
import os
+import warnings
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
@@ -156,7 +157,7 @@ def mock_completion() -> dict:
def test_groq_invoke(mock_completion: dict) -> None:
-    llm = ChatGroq()  # type: ignore[call-arg]
+    llm = ChatGroq(model="foo")
    mock_client = MagicMock()
    completed = False
@@ -178,7 +179,7 @@ def test_groq_invoke(mock_completion: dict) -> None:
async def test_groq_ainvoke(mock_completion: dict) -> None:
-    llm = ChatGroq()  # type: ignore[call-arg]
+    llm = ChatGroq(model="foo")
    mock_client = AsyncMock()
    completed = False
@@ -203,7 +204,7 @@ def test_chat_groq_extra_kwargs() -> None:
    """Test extra kwargs to chat groq."""
    # Check that foo is saved in extra_kwargs.
    with pytest.warns(UserWarning) as record:
-        llm = ChatGroq(foo=3, max_tokens=10)  # type: ignore[call-arg]
+        llm = ChatGroq(model="foo", foo=3, max_tokens=10)  # type: ignore[call-arg]
    assert llm.max_tokens == 10
    assert llm.model_kwargs == {"foo": 3}
    assert len(record) == 1
@@ -212,7 +213,7 @@ def test_chat_groq_extra_kwargs() -> None:
    # Test that if extra_kwargs are provided, they are added to it.
    with pytest.warns(UserWarning) as record:
-        llm = ChatGroq(foo=3, model_kwargs={"bar": 2})  # type: ignore[call-arg]
+        llm = ChatGroq(model="foo", foo=3, model_kwargs={"bar": 2})  # type: ignore[call-arg]
    assert llm.model_kwargs == {"foo": 3, "bar": 2}
    assert len(record) == 1
    assert type(record[0].message) is UserWarning
@@ -220,21 +221,22 @@ def test_chat_groq_extra_kwargs() -> None:
    # Test that if provided twice it errors
    with pytest.raises(ValueError):
-        ChatGroq(foo=3, model_kwargs={"foo": 2})  # type: ignore[call-arg]
+        ChatGroq(model="foo", foo=3, model_kwargs={"foo": 2})  # type: ignore[call-arg]

    # Test that if explicit param is specified in kwargs it errors
    with pytest.raises(ValueError):
-        ChatGroq(model_kwargs={"temperature": 0.2})  # type: ignore[call-arg]
+        ChatGroq(model="foo", model_kwargs={"temperature": 0.2})

    # Test that "model" cannot be specified in kwargs
    with pytest.raises(ValueError):
-        ChatGroq(model_kwargs={"model": "test-model"})  # type: ignore[call-arg]
+        ChatGroq(model="foo", model_kwargs={"model": "test-model"})


def test_chat_groq_invalid_streaming_params() -> None:
    """Test that an error is raised if streaming is invoked with n>1."""
    with pytest.raises(ValueError):
-        ChatGroq(  # type: ignore[call-arg]
+        ChatGroq(
+            model="foo",
            max_tokens=10,
            streaming=True,
            temperature=0,
@@ -246,7 +248,7 @@ def test_chat_groq_secret() -> None:
    """Test that secret is not printed"""
    secret = "secretKey"
    not_secret = "safe"
-    llm = ChatGroq(api_key=secret, model_kwargs={"not_secret": not_secret})  # type: ignore[call-arg, arg-type]
+    llm = ChatGroq(model="foo", api_key=secret, model_kwargs={"not_secret": not_secret})  # type: ignore[call-arg, arg-type]
    stringified = str(llm)
    assert not_secret in stringified
    assert secret not in stringified
@@ -257,7 +259,7 @@ def test_groq_serialization() -> None:
    """Test that ChatGroq can be successfully serialized and deserialized"""
    api_key1 = "top secret"
    api_key2 = "topest secret"
-    llm = ChatGroq(api_key=api_key1, temperature=0.5)  # type: ignore[call-arg, arg-type]
+    llm = ChatGroq(model="foo", api_key=api_key1, temperature=0.5)  # type: ignore[call-arg, arg-type]
    dump = lc_load.dumps(llm)
    llm2 = lc_load.loads(
        dump,
@@ -278,3 +280,23 @@ def test_groq_serialization() -> None:

    # Ensure a None was preserved
    assert llm.groq_api_base == llm2.groq_api_base
+
+
+def test_groq_warns_default_model() -> None:
+    """Test that a warning is raised if a default model is used."""
+    # Delete this test in 0.3 release, when the default model is removed.
+
+    # Test no warning if model is specified
+    with warnings.catch_warnings():
+        warnings.simplefilter("error")
+        ChatGroq(model="foo")
+
+    # Test warns if default model is used
+    with pytest.warns(match="default model"):
+        ChatGroq()
+
+    # Test only warns once
+    with warnings.catch_warnings():
+        warnings.simplefilter("error")
+        ChatGroq()
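
The `warnings.catch_warnings()` blocks above use `simplefilter("error")` to promote any warning into an exception, so the first and third constructions fail the test if they warn at all. The final `ChatGroq()` stays silent only because the module-level flag was already set inside the `pytest.warns` block, which is exactly the warn-once behavior being asserted.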


@@ -14,3 +14,7 @@ class TestGroqStandard(ChatModelUnitTests):
    @property
    def chat_model_class(self) -> Type[BaseChatModel]:
        return ChatGroq
+
+    @property
+    def chat_model_params(self) -> dict:
+        return {"model": "llama-3.1-8b-instant"}