From 74e7772a5fe6033a7c650130799d92f113630d87 Mon Sep 17 00:00:00 2001
From: ccurme
Date: Fri, 7 Mar 2025 15:21:13 -0500
Subject: [PATCH] groq[patch]: warn if model is not specified (#30161)

Groq is retiring `mixtral-8x7b-32768`, the current default model for
ChatGroq, on March 20, 2025. Here we emit a warning if the model is not
specified explicitly. Version 0.3.0, which removes the default
altogether, will be released ahead of March 20.
---
 .../groq/langchain_groq/chat_models.py        | 33 ++++++++++++---
 .../integration_tests/test_chat_models.py     | 42 ++++++++++---------
 .../__snapshots__/test_standard.ambr          |  2 +-
 .../groq/tests/unit_tests/test_chat_models.py | 42 ++++++++++++-----
 .../groq/tests/unit_tests/test_standard.py    |  4 ++
 5 files changed, 88 insertions(+), 35 deletions(-)

diff --git a/libs/partners/groq/langchain_groq/chat_models.py b/libs/partners/groq/langchain_groq/chat_models.py
index ab89354e152..bb4115b850a 100644
--- a/libs/partners/groq/langchain_groq/chat_models.py
+++ b/libs/partners/groq/langchain_groq/chat_models.py
@@ -88,6 +88,8 @@ from typing_extensions import Self
 
 from langchain_groq.version import __version__
 
+WARNED_DEFAULT_MODEL = False
+
 
 class ChatGroq(BaseChatModel):
     """`Groq` Chat large language models API.
@@ -109,7 +111,7 @@ class ChatGroq(BaseChatModel):
 
     Key init args — completion params:
         model: str
-            Name of Groq model to use. E.g. "mixtral-8x7b-32768".
+            Name of Groq model to use. E.g. "llama-3.1-8b-instant".
         temperature: float
             Sampling temperature. Ranges from 0.0 to 1.0.
         max_tokens: Optional[int]
@@ -140,7 +142,7 @@ class ChatGroq(BaseChatModel):
             from langchain_groq import ChatGroq
 
             llm = ChatGroq(
-                model="mixtral-8x7b-32768",
+                model="llama-3.1-8b-instant",
                 temperature=0.0,
                 max_retries=2,
                 # other params...
@@ -164,7 +166,7 @@ class ChatGroq(BaseChatModel):
           response_metadata={'token_usage': {'completion_tokens': 38,
           'prompt_tokens': 28, 'total_tokens': 66, 'completion_time':
           0.057975474, 'prompt_time': 0.005366091, 'queue_time': None,
-          'total_time': 0.063341565}, 'model_name': 'mixtral-8x7b-32768',
+          'total_time': 0.063341565}, 'model_name': 'llama-3.1-8b-instant',
           'system_fingerprint': 'fp_c5f20b5bb1', 'finish_reason': 'stop',
           'logprobs': None}, id='run-ecc71d70-e10c-4b69-8b8c-b8027d95d4b8-0')
@@ -222,7 +224,7 @@ class ChatGroq(BaseChatModel):
           response_metadata={'token_usage': {'completion_tokens': 53,
           'prompt_tokens': 28, 'total_tokens': 81, 'completion_time':
           0.083623752, 'prompt_time': 0.007365126, 'queue_time': None,
-          'total_time': 0.090988878}, 'model_name': 'mixtral-8x7b-32768',
+          'total_time': 0.090988878}, 'model_name': 'llama-3.1-8b-instant',
           'system_fingerprint': 'fp_c5f20b5bb1', 'finish_reason': 'stop',
           'logprobs': None}, id='run-897f3391-1bea-42e2-82e0-686e2367bcf8-0')
@@ -295,7 +297,7 @@ class ChatGroq(BaseChatModel):
              'prompt_time': 0.007518279,
              'queue_time': None,
              'total_time': 0.11947467},
-            'model_name': 'mixtral-8x7b-32768',
+            'model_name': 'llama-3.1-8b-instant',
             'system_fingerprint': 'fp_c5f20b5bb1',
             'finish_reason': 'stop',
             'logprobs': None}
@@ -351,6 +353,27 @@ class ChatGroq(BaseChatModel):
         populate_by_name=True,
     )
 
+    @model_validator(mode="before")
+    @classmethod
+    def warn_default_model(cls, values: Dict[str, Any]) -> Any:
+        """Warn about the upcoming removal of the default model."""
+        # TODO(ccurme): remove this warning in 0.3.0 when default model is removed
+        global WARNED_DEFAULT_MODEL
+        if (
+            "model" not in values
+            and "model_name" not in values
+            and not WARNED_DEFAULT_MODEL
+        ):
+            warnings.warn(
+                "Groq is retiring the default model for ChatGroq, mixtral-8x7b-32768, "
+                "on March 20, 2025. Requests with the default model will start failing "
+                "on that date. Version 0.3.0 of langchain-groq will remove the "
+                "default. Please specify `model` explicitly, e.g., "
+                "`model='mistral-saba-24b'` or `model='llama-3.3-70b-versatile'`.",
+            )
+            WARNED_DEFAULT_MODEL = True
+        return values
+
     @model_validator(mode="before")
     @classmethod
     def build_extra(cls, values: Dict[str, Any]) -> Any:
diff --git a/libs/partners/groq/tests/integration_tests/test_chat_models.py b/libs/partners/groq/tests/integration_tests/test_chat_models.py
index 9113c829871..9d74ef4f2ab 100644
--- a/libs/partners/groq/tests/integration_tests/test_chat_models.py
+++ b/libs/partners/groq/tests/integration_tests/test_chat_models.py
@@ -21,6 +21,8 @@ from tests.unit_tests.fake.callbacks import (
     FakeCallbackHandlerWithChatStart,
 )
 
+MODEL_NAME = "llama-3.3-70b-versatile"
+
 
 #
 # Smoke test Runnable interface
@@ -28,7 +30,8 @@ from tests.unit_tests.fake.callbacks import (
 @pytest.mark.scheduled
 def test_invoke() -> None:
     """Test Chat wrapper."""
-    chat = ChatGroq(  # type: ignore[call-arg]
+    chat = ChatGroq(
+        model=MODEL_NAME,
         temperature=0.7,
         base_url=None,
         groq_proxy=None,
@@ -49,7 +52,7 @@ def test_invoke() -> None:
 @pytest.mark.scheduled
 async def test_ainvoke() -> None:
     """Test ainvoke tokens from ChatGroq."""
-    chat = ChatGroq(max_tokens=10)  # type: ignore[call-arg]
+    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)
 
     result = await chat.ainvoke("Welcome to the Groqetship!", config={"tags": ["foo"]})
     assert isinstance(result, BaseMessage)
@@ -59,7 +62,7 @@ async def test_ainvoke() -> None:
 @pytest.mark.scheduled
 def test_batch() -> None:
     """Test batch tokens from ChatGroq."""
-    chat = ChatGroq(max_tokens=10)  # type: ignore[call-arg]
+    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)
 
     result = chat.batch(["Hello!", "Welcome to the Groqetship!"])
     for token in result:
@@ -70,7 +73,7 @@ def test_batch() -> None:
 @pytest.mark.scheduled
 async def test_abatch() -> None:
     """Test abatch tokens from ChatGroq."""
-    chat = ChatGroq(max_tokens=10)  # type: ignore[call-arg]
+    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)
 
     result = await chat.abatch(["Hello!", "Welcome to the Groqetship!"])
     for token in result:
@@ -81,7 +84,7 @@ async def test_abatch() -> None:
 @pytest.mark.scheduled
 async def test_stream() -> None:
     """Test streaming tokens from Groq."""
-    chat = ChatGroq(max_tokens=10)  # type: ignore[call-arg]
+    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)
 
     for token in chat.stream("Welcome to the Groqetship!"):
         assert isinstance(token, BaseMessageChunk)
@@ -91,7 +94,7 @@ async def test_stream() -> None:
 @pytest.mark.scheduled
 async def test_astream() -> None:
     """Test streaming tokens from Groq."""
-    chat = ChatGroq(max_tokens=10)  # type: ignore[call-arg]
+    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)
 
     full: Optional[BaseMessageChunk] = None
     chunks_with_token_counts = 0
@@ -124,7 +127,7 @@ async def test_astream() -> None:
 def test_generate() -> None:
     """Test sync generate."""
     n = 1
-    chat = ChatGroq(max_tokens=10)  # type: ignore[call-arg]
+    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)
     message = HumanMessage(content="Hello", n=1)
     response = chat.generate([[message], [message]])
     assert isinstance(response, LLMResult)
@@ -143,7 +146,7 @@ def test_generate() -> None:
 async def test_agenerate() -> None:
     """Test async generation."""
     n = 1
-    chat = ChatGroq(max_tokens=10, n=1)  # type: ignore[call-arg]
+    chat = ChatGroq(model=MODEL_NAME, max_tokens=10, n=1)
     message = HumanMessage(content="Hello")
     response = await chat.agenerate([[message], [message]])
     assert isinstance(response, LLMResult)
@@ -165,7 +168,8 @@ async def test_agenerate() -> None:
 def test_invoke_streaming() -> None:
     """Test that streaming correctly invokes on_llm_new_token callback."""
     callback_handler = FakeCallbackHandler()
-    chat = ChatGroq(  # type: ignore[call-arg]
+    chat = ChatGroq(
+        model=MODEL_NAME,
         max_tokens=2,
         streaming=True,
         temperature=0,
@@ -181,7 +185,8 @@ def test_invoke_streaming() -> None:
 async def test_agenerate_streaming() -> None:
     """Test that streaming correctly invokes on_llm_new_token callback."""
     callback_handler = FakeCallbackHandlerWithChatStart()
-    chat = ChatGroq(  # type: ignore[call-arg]
+    chat = ChatGroq(
+        model=MODEL_NAME,
         max_tokens=10,
         streaming=True,
         temperature=0,
@@ -220,7 +225,8 @@ def test_streaming_generation_info() -> None:
             self.saved_things["generation"] = args[0]
 
     callback = _FakeCallback()
-    chat = ChatGroq(  # type: ignore[call-arg]
+    chat = ChatGroq(
+        model=MODEL_NAME,
         max_tokens=2,
         temperature=0,
         callbacks=[callback],
@@ -234,7 +240,7 @@ def test_streaming_generation_info() -> None:
 
 def test_system_message() -> None:
     """Test ChatGroq wrapper with system message."""
-    chat = ChatGroq(max_tokens=10)  # type: ignore[call-arg]
+    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)
     system_message = SystemMessage(content="You are to chat with the user.")
     human_message = HumanMessage(content="Hello")
     response = chat.invoke([system_message, human_message])
@@ -242,10 +248,9 @@ def test_system_message() -> None:
     assert isinstance(response.content, str)
 
 
-@pytest.mark.xfail(reason="Groq tool_choice doesn't currently force a tool call")
 def test_tool_choice() -> None:
     """Test that tool choice is respected."""
-    llm = ChatGroq()  # type: ignore[call-arg]
+    llm = ChatGroq(model=MODEL_NAME)
 
     class MyTool(BaseModel):
         name: str
@@ -273,10 +278,9 @@ def test_tool_choice() -> None:
     assert tool_call["args"] == {"name": "Erick", "age": 27}
 
 
-@pytest.mark.xfail(reason="Groq tool_choice doesn't currently force a tool call")
 def test_tool_choice_bool() -> None:
     """Test that tool choice is respected just passing in True."""
-    llm = ChatGroq()  # type: ignore[call-arg]
+    llm = ChatGroq(model=MODEL_NAME)
 
     class MyTool(BaseModel):
         name: str
@@ -301,7 +305,7 @@ def test_tool_choice_bool() -> None:
 @pytest.mark.xfail(reason="Groq tool_choice doesn't currently force a tool call")
 def test_streaming_tool_call() -> None:
     """Test that tool choice is respected."""
-    llm = ChatGroq()  # type: ignore[call-arg]
+    llm = ChatGroq(model=MODEL_NAME)
 
     class MyTool(BaseModel):
         name: str
@@ -339,7 +343,7 @@ def test_streaming_tool_call() -> None:
 @pytest.mark.xfail(reason="Groq tool_choice doesn't currently force a tool call")
 async def test_astreaming_tool_call() -> None:
     """Test that tool choice is respected."""
-    llm = ChatGroq()  # type: ignore[call-arg]
+    llm = ChatGroq(model=MODEL_NAME)
 
     class MyTool(BaseModel):
         name: str
@@ -384,7 +388,7 @@ def test_json_mode_structured_output() -> None:
         setup: str = Field(description="question to set up a joke")
         punchline: str = Field(description="answer to resolve the joke")
 
-    chat = ChatGroq().with_structured_output(Joke, method="json_mode")  # type: ignore[call-arg]
+    chat = ChatGroq(model=MODEL_NAME).with_structured_output(Joke, method="json_mode")
     result = chat.invoke(
         "Tell me a joke about cats, respond in JSON with `setup` and `punchline` keys"
     )
diff --git a/libs/partners/groq/tests/unit_tests/__snapshots__/test_standard.ambr b/libs/partners/groq/tests/unit_tests/__snapshots__/test_standard.ambr
index 741d2c84745..7b8db708167 100644
--- a/libs/partners/groq/tests/unit_tests/__snapshots__/test_standard.ambr
+++ b/libs/partners/groq/tests/unit_tests/__snapshots__/test_standard.ambr
@@ -16,7 +16,7 @@
       }),
       'max_retries': 2,
       'max_tokens': 100,
-      'model_name': 'mixtral-8x7b-32768',
+      'model_name': 'llama-3.1-8b-instant',
       'n': 1,
       'request_timeout': 60.0,
       'stop': list([
diff --git a/libs/partners/groq/tests/unit_tests/test_chat_models.py b/libs/partners/groq/tests/unit_tests/test_chat_models.py
index 1d7ec63c514..583b562b285 100644
--- a/libs/partners/groq/tests/unit_tests/test_chat_models.py
+++ b/libs/partners/groq/tests/unit_tests/test_chat_models.py
@@ -2,6 +2,7 @@
 
 import json
 import os
+import warnings
 from typing import Any
 from unittest.mock import AsyncMock, MagicMock, patch
@@ -156,7 +157,7 @@ def mock_completion() -> dict:
 
 
 def test_groq_invoke(mock_completion: dict) -> None:
-    llm = ChatGroq()  # type: ignore[call-arg]
+    llm = ChatGroq(model="foo")
     mock_client = MagicMock()
     completed = False
@@ -178,7 +179,7 @@ def test_groq_invoke(mock_completion: dict) -> None:
 
 
 async def test_groq_ainvoke(mock_completion: dict) -> None:
-    llm = ChatGroq()  # type: ignore[call-arg]
+    llm = ChatGroq(model="foo")
     mock_client = AsyncMock()
     completed = False
@@ -203,7 +204,7 @@ def test_chat_groq_extra_kwargs() -> None:
     """Test extra kwargs to chat groq."""
     # Check that foo is saved in extra_kwargs.
     with pytest.warns(UserWarning) as record:
-        llm = ChatGroq(foo=3, max_tokens=10)  # type: ignore[call-arg]
+        llm = ChatGroq(model="foo", foo=3, max_tokens=10)  # type: ignore[call-arg]
     assert llm.max_tokens == 10
     assert llm.model_kwargs == {"foo": 3}
     assert len(record) == 1
@@ -212,7 +213,7 @@
 
     # Test that if extra_kwargs are provided, they are added to it.
     with pytest.warns(UserWarning) as record:
-        llm = ChatGroq(foo=3, model_kwargs={"bar": 2})  # type: ignore[call-arg]
+        llm = ChatGroq(model="foo", foo=3, model_kwargs={"bar": 2})  # type: ignore[call-arg]
     assert llm.model_kwargs == {"foo": 3, "bar": 2}
     assert len(record) == 1
     assert type(record[0].message) is UserWarning
@@ -220,21 +221,22 @@
 
     # Test that if provided twice it errors
     with pytest.raises(ValueError):
-        ChatGroq(foo=3, model_kwargs={"foo": 2})  # type: ignore[call-arg]
+        ChatGroq(model="foo", foo=3, model_kwargs={"foo": 2})  # type: ignore[call-arg]
 
     # Test that if explicit param is specified in kwargs it errors
     with pytest.raises(ValueError):
-        ChatGroq(model_kwargs={"temperature": 0.2})  # type: ignore[call-arg]
+        ChatGroq(model="foo", model_kwargs={"temperature": 0.2})
 
     # Test that "model" cannot be specified in kwargs
     with pytest.raises(ValueError):
-        ChatGroq(model_kwargs={"model": "test-model"})  # type: ignore[call-arg]
+        ChatGroq(model="foo", model_kwargs={"model": "test-model"})
 
 
 def test_chat_groq_invalid_streaming_params() -> None:
     """Test that an error is raised if streaming is invoked with n>1."""
     with pytest.raises(ValueError):
-        ChatGroq(  # type: ignore[call-arg]
+        ChatGroq(
+            model="foo",
             max_tokens=10,
             streaming=True,
             temperature=0,
@@ -246,7 +248,7 @@ def test_chat_groq_secret() -> None:
     """Test that secret is not printed"""
     secret = "secretKey"
     not_secret = "safe"
-    llm = ChatGroq(api_key=secret, model_kwargs={"not_secret": not_secret})  # type: ignore[call-arg, arg-type]
+    llm = ChatGroq(model="foo", api_key=secret, model_kwargs={"not_secret": not_secret})  # type: ignore[call-arg, arg-type]
     stringified = str(llm)
     assert not_secret in stringified
     assert secret not in stringified
@@ -257,7 +259,7 @@ def test_groq_serialization() -> None:
     """Test that ChatGroq can be successfully serialized and deserialized"""
     api_key1 = "top secret"
     api_key2 = "topest secret"
-    llm = ChatGroq(api_key=api_key1, temperature=0.5)  # type: ignore[call-arg, arg-type]
+    llm = ChatGroq(model="foo", api_key=api_key1, temperature=0.5)  # type: ignore[call-arg, arg-type]
     dump = lc_load.dumps(llm)
     llm2 = lc_load.loads(
         dump,
@@ -278,3 +280,23 @@ def test_groq_serialization() -> None:
 
     # Ensure a None was preserved
     assert llm.groq_api_base == llm2.groq_api_base
+
+
+def test_groq_warns_default_model() -> None:
+    """Test that a warning is raised if the default model is used."""
+
+    # Delete this test in 0.3 release, when the default model is removed.
+
+    # Test no warning if model is specified
+    with warnings.catch_warnings():
+        warnings.simplefilter("error")
+        ChatGroq(model="foo")
+
+    # Test warns if default model is used
+    with pytest.warns(match="default model"):
+        ChatGroq()
+
+    # Test only warns once
+    with warnings.catch_warnings():
+        warnings.simplefilter("error")
+        ChatGroq()
diff --git a/libs/partners/groq/tests/unit_tests/test_standard.py b/libs/partners/groq/tests/unit_tests/test_standard.py
index e4df2916f30..f04d13b703f 100644
--- a/libs/partners/groq/tests/unit_tests/test_standard.py
+++ b/libs/partners/groq/tests/unit_tests/test_standard.py
@@ -14,3 +14,7 @@ class TestGroqStandard(ChatModelUnitTests):
     @property
     def chat_model_class(self) -> Type[BaseChatModel]:
         return ChatGroq
+
+    @property
+    def chat_model_params(self) -> dict:
+        return {"model": "llama-3.1-8b-instant"}
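
Reviewer note (not part of the patch): a minimal caller-side sketch of the
change this warning steers users toward. The model id below is one of the
alternatives named in the warning text, and a valid GROQ_API_KEY is assumed
to be set in the environment.

    from langchain_groq import ChatGroq

    # Passing `model` explicitly silences the new warning and keeps requests
    # working after mixtral-8x7b-32768 is retired on March 20, 2025.
    llm = ChatGroq(model="llama-3.3-70b-versatile", temperature=0.0)
    print(llm.invoke("Hello!").content)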