groq[patch]: warn if model is not specified (#30161)
Groq is retiring `mixtral-8x7b-32768`, currently the default model for `ChatGroq`, on March 20, 2025. This change emits a warning whenever `model` is not specified explicitly. Version 0.3.0, which removes the default altogether, will be released ahead of March 20.
parent 3444e587ee
commit 74e7772a5f
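Illustratively, a minimal sketch of the behavior change from the caller's side (assumes `GROQ_API_KEY` is set in the environment; the model name below is one of the alternatives suggested in the new warning, which is only emitted once per process):

```python
from langchain_groq import ChatGroq

# Relying on the default model now emits a UserWarning; requests against
# mixtral-8x7b-32768 will start failing once Groq retires it on March 20, 2025.
llm = ChatGroq()  # UserWarning: "Groq is retiring the default model for ChatGroq, ..."

# Passing `model` explicitly avoids the warning and keeps working after the
# retirement date (and becomes mandatory once 0.3.0 drops the default).
llm = ChatGroq(model="llama-3.3-70b-versatile", temperature=0)
llm.invoke("Hello!")
```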
@@ -88,6 +88,8 @@ from typing_extensions import Self

 from langchain_groq.version import __version__

+WARNED_DEFAULT_MODEL = False
+

 class ChatGroq(BaseChatModel):
     """`Groq` Chat large language models API.
@@ -109,7 +111,7 @@ class ChatGroq(BaseChatModel):

     Key init args — completion params:
         model: str
-            Name of Groq model to use. E.g. "mixtral-8x7b-32768".
+            Name of Groq model to use. E.g. "llama-3.1-8b-instant".
         temperature: float
             Sampling temperature. Ranges from 0.0 to 1.0.
         max_tokens: Optional[int]
@@ -140,7 +142,7 @@ class ChatGroq(BaseChatModel):

            from langchain_groq import ChatGroq

            llm = ChatGroq(
-                model="mixtral-8x7b-32768",
+                model="llama-3.1-8b-instant",
                temperature=0.0,
                max_retries=2,
                # other params...
@@ -164,7 +166,7 @@ class ChatGroq(BaseChatModel):
            response_metadata={'token_usage': {'completion_tokens': 38,
            'prompt_tokens': 28, 'total_tokens': 66, 'completion_time':
            0.057975474, 'prompt_time': 0.005366091, 'queue_time': None,
-            'total_time': 0.063341565}, 'model_name': 'mixtral-8x7b-32768',
+            'total_time': 0.063341565}, 'model_name': 'llama-3.1-8b-instant',
            'system_fingerprint': 'fp_c5f20b5bb1', 'finish_reason': 'stop',
            'logprobs': None}, id='run-ecc71d70-e10c-4b69-8b8c-b8027d95d4b8-0')
@@ -222,7 +224,7 @@ class ChatGroq(BaseChatModel):
            response_metadata={'token_usage': {'completion_tokens': 53,
            'prompt_tokens': 28, 'total_tokens': 81, 'completion_time':
            0.083623752, 'prompt_time': 0.007365126, 'queue_time': None,
-            'total_time': 0.090988878}, 'model_name': 'mixtral-8x7b-32768',
+            'total_time': 0.090988878}, 'model_name': 'llama-3.1-8b-instant',
            'system_fingerprint': 'fp_c5f20b5bb1', 'finish_reason': 'stop',
            'logprobs': None}, id='run-897f3391-1bea-42e2-82e0-686e2367bcf8-0')
@@ -295,7 +297,7 @@ class ChatGroq(BaseChatModel):
            'prompt_time': 0.007518279,
            'queue_time': None,
            'total_time': 0.11947467},
-            'model_name': 'mixtral-8x7b-32768',
+            'model_name': 'llama-3.1-8b-instant',
            'system_fingerprint': 'fp_c5f20b5bb1',
            'finish_reason': 'stop',
            'logprobs': None}
@@ -351,6 +353,27 @@ class ChatGroq(BaseChatModel):
         populate_by_name=True,
     )

+    @model_validator(mode="before")
+    @classmethod
+    def warn_default_model(cls, values: Dict[str, Any]) -> Any:
+        """Warning anticipating removal of default model."""
+        # TODO(ccurme): remove this warning in 0.3.0 when default model is removed
+        global WARNED_DEFAULT_MODEL
+        if (
+            "model" not in values
+            and "model_name" not in values
+            and not WARNED_DEFAULT_MODEL
+        ):
+            warnings.warn(
+                "Groq is retiring the default model for ChatGroq, mixtral-8x7b-32768, "
+                "on March 20, 2025. Requests with the default model will start failing "
+                "on that date. Version 0.3.0 of langchain-groq will remove the "
+                "default. Please specify `model` explicitly, e.g., "
+                "`model='mistral-saba-24b'` or `model='llama-3.3-70b-versatile'`.",
+            )
+            WARNED_DEFAULT_MODEL = True
+        return values
+
     @model_validator(mode="before")
     @classmethod
     def build_extra(cls, values: Dict[str, Any]) -> Any:
@@ -21,6 +21,8 @@ from tests.unit_tests.fake.callbacks import (
     FakeCallbackHandlerWithChatStart,
 )

+MODEL_NAME = "llama-3.3-70b-versatile"
+

 #
 # Smoke test Runnable interface
@@ -28,7 +30,8 @@ from tests.unit_tests.fake.callbacks import (
 @pytest.mark.scheduled
 def test_invoke() -> None:
     """Test Chat wrapper."""
-    chat = ChatGroq(  # type: ignore[call-arg]
+    chat = ChatGroq(
+        model=MODEL_NAME,
         temperature=0.7,
         base_url=None,
         groq_proxy=None,
@@ -49,7 +52,7 @@ def test_invoke() -> None:
 @pytest.mark.scheduled
 async def test_ainvoke() -> None:
     """Test ainvoke tokens from ChatGroq."""
-    chat = ChatGroq(max_tokens=10)  # type: ignore[call-arg]
+    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)

     result = await chat.ainvoke("Welcome to the Groqetship!", config={"tags": ["foo"]})
     assert isinstance(result, BaseMessage)
@@ -59,7 +62,7 @@ async def test_ainvoke() -> None:
 @pytest.mark.scheduled
 def test_batch() -> None:
     """Test batch tokens from ChatGroq."""
-    chat = ChatGroq(max_tokens=10)  # type: ignore[call-arg]
+    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)

     result = chat.batch(["Hello!", "Welcome to the Groqetship!"])
     for token in result:
@@ -70,7 +73,7 @@ def test_batch() -> None:
 @pytest.mark.scheduled
 async def test_abatch() -> None:
     """Test abatch tokens from ChatGroq."""
-    chat = ChatGroq(max_tokens=10)  # type: ignore[call-arg]
+    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)

     result = await chat.abatch(["Hello!", "Welcome to the Groqetship!"])
     for token in result:
@@ -81,7 +84,7 @@ async def test_abatch() -> None:
 @pytest.mark.scheduled
 async def test_stream() -> None:
     """Test streaming tokens from Groq."""
-    chat = ChatGroq(max_tokens=10)  # type: ignore[call-arg]
+    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)

     for token in chat.stream("Welcome to the Groqetship!"):
         assert isinstance(token, BaseMessageChunk)
@@ -91,7 +94,7 @@ async def test_stream() -> None:
 @pytest.mark.scheduled
 async def test_astream() -> None:
     """Test streaming tokens from Groq."""
-    chat = ChatGroq(max_tokens=10)  # type: ignore[call-arg]
+    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)

     full: Optional[BaseMessageChunk] = None
     chunks_with_token_counts = 0
@@ -124,7 +127,7 @@ async def test_astream() -> None:
 def test_generate() -> None:
     """Test sync generate."""
     n = 1
-    chat = ChatGroq(max_tokens=10)  # type: ignore[call-arg]
+    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)
     message = HumanMessage(content="Hello", n=1)
     response = chat.generate([[message], [message]])
     assert isinstance(response, LLMResult)
@@ -143,7 +146,7 @@ def test_generate() -> None:
 async def test_agenerate() -> None:
     """Test async generation."""
     n = 1
-    chat = ChatGroq(max_tokens=10, n=1)  # type: ignore[call-arg]
+    chat = ChatGroq(model=MODEL_NAME, max_tokens=10, n=1)
     message = HumanMessage(content="Hello")
     response = await chat.agenerate([[message], [message]])
     assert isinstance(response, LLMResult)
@@ -165,7 +168,8 @@ async def test_agenerate() -> None:
 def test_invoke_streaming() -> None:
     """Test that streaming correctly invokes on_llm_new_token callback."""
     callback_handler = FakeCallbackHandler()
-    chat = ChatGroq(  # type: ignore[call-arg]
+    chat = ChatGroq(
+        model=MODEL_NAME,
         max_tokens=2,
         streaming=True,
         temperature=0,
@@ -181,7 +185,8 @@ def test_invoke_streaming() -> None:
 async def test_agenerate_streaming() -> None:
     """Test that streaming correctly invokes on_llm_new_token callback."""
     callback_handler = FakeCallbackHandlerWithChatStart()
-    chat = ChatGroq(  # type: ignore[call-arg]
+    chat = ChatGroq(
+        model=MODEL_NAME,
         max_tokens=10,
         streaming=True,
         temperature=0,
@@ -220,7 +225,8 @@ def test_streaming_generation_info() -> None:
             self.saved_things["generation"] = args[0]

     callback = _FakeCallback()
-    chat = ChatGroq(  # type: ignore[call-arg]
+    chat = ChatGroq(
+        model=MODEL_NAME,
         max_tokens=2,
         temperature=0,
         callbacks=[callback],
@@ -234,7 +240,7 @@ def test_streaming_generation_info() -> None:

 def test_system_message() -> None:
     """Test ChatGroq wrapper with system message."""
-    chat = ChatGroq(max_tokens=10)  # type: ignore[call-arg]
+    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)
     system_message = SystemMessage(content="You are to chat with the user.")
     human_message = HumanMessage(content="Hello")
     response = chat.invoke([system_message, human_message])
@@ -242,10 +248,9 @@ def test_system_message() -> None:
     assert isinstance(response.content, str)


-@pytest.mark.xfail(reason="Groq tool_choice doesn't currently force a tool call")
 def test_tool_choice() -> None:
     """Test that tool choice is respected."""
-    llm = ChatGroq()  # type: ignore[call-arg]
+    llm = ChatGroq(model=MODEL_NAME)

     class MyTool(BaseModel):
         name: str
@@ -273,10 +278,9 @@ def test_tool_choice() -> None:
     assert tool_call["args"] == {"name": "Erick", "age": 27}


-@pytest.mark.xfail(reason="Groq tool_choice doesn't currently force a tool call")
 def test_tool_choice_bool() -> None:
     """Test that tool choice is respected just passing in True."""
-    llm = ChatGroq()  # type: ignore[call-arg]
+    llm = ChatGroq(model=MODEL_NAME)

     class MyTool(BaseModel):
         name: str
@@ -301,7 +305,7 @@ def test_tool_choice_bool() -> None:
 @pytest.mark.xfail(reason="Groq tool_choice doesn't currently force a tool call")
 def test_streaming_tool_call() -> None:
     """Test that tool choice is respected."""
-    llm = ChatGroq()  # type: ignore[call-arg]
+    llm = ChatGroq(model=MODEL_NAME)

     class MyTool(BaseModel):
         name: str
@@ -339,7 +343,7 @@ def test_streaming_tool_call() -> None:
 @pytest.mark.xfail(reason="Groq tool_choice doesn't currently force a tool call")
 async def test_astreaming_tool_call() -> None:
     """Test that tool choice is respected."""
-    llm = ChatGroq()  # type: ignore[call-arg]
+    llm = ChatGroq(model=MODEL_NAME)

     class MyTool(BaseModel):
         name: str
@@ -384,7 +388,7 @@ def test_json_mode_structured_output() -> None:
         setup: str = Field(description="question to set up a joke")
         punchline: str = Field(description="answer to resolve the joke")

-    chat = ChatGroq().with_structured_output(Joke, method="json_mode")  # type: ignore[call-arg]
+    chat = ChatGroq(model=MODEL_NAME).with_structured_output(Joke, method="json_mode")
     result = chat.invoke(
         "Tell me a joke about cats, respond in JSON with `setup` and `punchline` keys"
     )
@@ -16,7 +16,7 @@
      }),
      'max_retries': 2,
      'max_tokens': 100,
-      'model_name': 'mixtral-8x7b-32768',
+      'model_name': 'llama-3.1-8b-instant',
      'n': 1,
      'request_timeout': 60.0,
      'stop': list([
@@ -2,6 +2,7 @@

 import json
 import os
+import warnings
 from typing import Any
 from unittest.mock import AsyncMock, MagicMock, patch

@@ -156,7 +157,7 @@ def mock_completion() -> dict:


 def test_groq_invoke(mock_completion: dict) -> None:
-    llm = ChatGroq()  # type: ignore[call-arg]
+    llm = ChatGroq(model="foo")
     mock_client = MagicMock()
     completed = False

@@ -178,7 +179,7 @@ def test_groq_invoke(mock_completion: dict) -> None:


 async def test_groq_ainvoke(mock_completion: dict) -> None:
-    llm = ChatGroq()  # type: ignore[call-arg]
+    llm = ChatGroq(model="foo")
     mock_client = AsyncMock()
     completed = False

@@ -203,7 +204,7 @@ def test_chat_groq_extra_kwargs() -> None:
     """Test extra kwargs to chat groq."""
     # Check that foo is saved in extra_kwargs.
     with pytest.warns(UserWarning) as record:
-        llm = ChatGroq(foo=3, max_tokens=10)  # type: ignore[call-arg]
+        llm = ChatGroq(model="foo", foo=3, max_tokens=10)  # type: ignore[call-arg]
     assert llm.max_tokens == 10
     assert llm.model_kwargs == {"foo": 3}
     assert len(record) == 1
@@ -212,7 +213,7 @@

     # Test that if extra_kwargs are provided, they are added to it.
     with pytest.warns(UserWarning) as record:
-        llm = ChatGroq(foo=3, model_kwargs={"bar": 2})  # type: ignore[call-arg]
+        llm = ChatGroq(model="foo", foo=3, model_kwargs={"bar": 2})  # type: ignore[call-arg]
     assert llm.model_kwargs == {"foo": 3, "bar": 2}
     assert len(record) == 1
     assert type(record[0].message) is UserWarning
@@ -220,21 +221,22 @@

     # Test that if provided twice it errors
     with pytest.raises(ValueError):
-        ChatGroq(foo=3, model_kwargs={"foo": 2})  # type: ignore[call-arg]
+        ChatGroq(model="foo", foo=3, model_kwargs={"foo": 2})  # type: ignore[call-arg]

     # Test that if explicit param is specified in kwargs it errors
     with pytest.raises(ValueError):
-        ChatGroq(model_kwargs={"temperature": 0.2})  # type: ignore[call-arg]
+        ChatGroq(model="foo", model_kwargs={"temperature": 0.2})

     # Test that "model" cannot be specified in kwargs
     with pytest.raises(ValueError):
-        ChatGroq(model_kwargs={"model": "test-model"})  # type: ignore[call-arg]
+        ChatGroq(model="foo", model_kwargs={"model": "test-model"})


 def test_chat_groq_invalid_streaming_params() -> None:
     """Test that an error is raised if streaming is invoked with n>1."""
     with pytest.raises(ValueError):
-        ChatGroq(  # type: ignore[call-arg]
+        ChatGroq(
+            model="foo",
             max_tokens=10,
             streaming=True,
             temperature=0,
@@ -246,7 +248,7 @@ def test_chat_groq_secret() -> None:
     """Test that secret is not printed"""
     secret = "secretKey"
     not_secret = "safe"
-    llm = ChatGroq(api_key=secret, model_kwargs={"not_secret": not_secret})  # type: ignore[call-arg, arg-type]
+    llm = ChatGroq(model="foo", api_key=secret, model_kwargs={"not_secret": not_secret})  # type: ignore[call-arg, arg-type]
     stringified = str(llm)
     assert not_secret in stringified
     assert secret not in stringified
@@ -257,7 +259,7 @@ def test_groq_serialization() -> None:
     """Test that ChatGroq can be successfully serialized and deserialized"""
     api_key1 = "top secret"
     api_key2 = "topest secret"
-    llm = ChatGroq(api_key=api_key1, temperature=0.5)  # type: ignore[call-arg, arg-type]
+    llm = ChatGroq(model="foo", api_key=api_key1, temperature=0.5)  # type: ignore[call-arg, arg-type]
     dump = lc_load.dumps(llm)
     llm2 = lc_load.loads(
         dump,
@@ -278,3 +280,23 @@ def test_groq_serialization() -> None:

     # Ensure a None was preserved
     assert llm.groq_api_base == llm2.groq_api_base
+
+
+def test_groq_warns_default_model() -> None:
+    """Test that a warning is raised if a default model is used."""
+
+    # Delete this test in 0.3 release, when the default model is removed.
+
+    # Test no warning if model is specified
+    with warnings.catch_warnings():
+        warnings.simplefilter("error")
+        ChatGroq(model="foo")
+
+    # Test warns if default model is used
+    with pytest.warns(match="default model"):
+        ChatGroq()
+
+    # Test only warns once
+    with warnings.catch_warnings():
+        warnings.simplefilter("error")
+        ChatGroq()
@@ -14,3 +14,7 @@ class TestGroqStandard(ChatModelUnitTests):
     @property
     def chat_model_class(self) -> Type[BaseChatModel]:
         return ChatGroq
+
+    @property
+    def chat_model_params(self) -> dict:
+        return {"model": "llama-3.1-8b-instant"}