Mirror of https://github.com/hwchase17/langchain.git, synced 2025-06-22 06:39:52 +00:00
groq[patch]: warn if model is not specified (#30161)
Groq is retiring `mixtral-8x7b-32768`, currently the default model for ChatGroq, on March 20, 2025. Here we emit a warning if the model is not specified explicitly. A version 0.3.0 will be released ahead of March 20 that removes the default altogether.
This commit is contained in: parent 3444e587ee, commit 74e7772a5f
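For users of `ChatGroq`, the practical takeaway is to pass `model` explicitly when constructing the client. Below is a minimal usage sketch, not part of the commit itself: it assumes a valid `GROQ_API_KEY` in the environment, and the model name is simply one of the alternatives suggested by the warning text added in this change.

    from langchain_groq import ChatGroq

    # Passing `model` explicitly means warn_default_model sees "model" in the
    # constructor kwargs, so no deprecation warning is emitted.
    llm = ChatGroq(model="llama-3.3-70b-versatile", temperature=0.0)
    print(llm.invoke("Hello, Groq!").content)

Constructions that omit `model` keep working until the default is removed in 0.3.0, but they trigger the warning shown in the diff below, at most once per process because of the module-level `WARNED_DEFAULT_MODEL` flag.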
@@ -88,6 +88,8 @@ from typing_extensions import Self
 
 from langchain_groq.version import __version__
 
+WARNED_DEFAULT_MODEL = False
+
 
 class ChatGroq(BaseChatModel):
     """`Groq` Chat large language models API.
@@ -109,7 +111,7 @@ class ChatGroq(BaseChatModel):
 
     Key init args — completion params:
         model: str
-            Name of Groq model to use. E.g. "mixtral-8x7b-32768".
+            Name of Groq model to use. E.g. "llama-3.1-8b-instant".
         temperature: float
             Sampling temperature. Ranges from 0.0 to 1.0.
         max_tokens: Optional[int]
@@ -140,7 +142,7 @@ class ChatGroq(BaseChatModel):
             from langchain_groq import ChatGroq
 
             llm = ChatGroq(
-                model="mixtral-8x7b-32768",
+                model="llama-3.1-8b-instant",
                 temperature=0.0,
                 max_retries=2,
                 # other params...
@@ -164,7 +166,7 @@ class ChatGroq(BaseChatModel):
            response_metadata={'token_usage': {'completion_tokens': 38,
            'prompt_tokens': 28, 'total_tokens': 66, 'completion_time':
            0.057975474, 'prompt_time': 0.005366091, 'queue_time': None,
-           'total_time': 0.063341565}, 'model_name': 'mixtral-8x7b-32768',
+           'total_time': 0.063341565}, 'model_name': 'llama-3.1-8b-instant',
            'system_fingerprint': 'fp_c5f20b5bb1', 'finish_reason': 'stop',
            'logprobs': None}, id='run-ecc71d70-e10c-4b69-8b8c-b8027d95d4b8-0')
 
@@ -222,7 +224,7 @@ class ChatGroq(BaseChatModel):
            response_metadata={'token_usage': {'completion_tokens': 53,
            'prompt_tokens': 28, 'total_tokens': 81, 'completion_time':
            0.083623752, 'prompt_time': 0.007365126, 'queue_time': None,
-           'total_time': 0.090988878}, 'model_name': 'mixtral-8x7b-32768',
+           'total_time': 0.090988878}, 'model_name': 'llama-3.1-8b-instant',
            'system_fingerprint': 'fp_c5f20b5bb1', 'finish_reason': 'stop',
            'logprobs': None}, id='run-897f3391-1bea-42e2-82e0-686e2367bcf8-0')
 
@@ -295,7 +297,7 @@ class ChatGroq(BaseChatModel):
             'prompt_time': 0.007518279,
             'queue_time': None,
             'total_time': 0.11947467},
-            'model_name': 'mixtral-8x7b-32768',
+            'model_name': 'llama-3.1-8b-instant',
             'system_fingerprint': 'fp_c5f20b5bb1',
             'finish_reason': 'stop',
             'logprobs': None}
@@ -351,6 +353,27 @@ class ChatGroq(BaseChatModel):
         populate_by_name=True,
     )
 
+    @model_validator(mode="before")
+    @classmethod
+    def warn_default_model(cls, values: Dict[str, Any]) -> Any:
+        """Warning anticipating removal of default model."""
+        # TODO(ccurme): remove this warning in 0.3.0 when default model is removed
+        global WARNED_DEFAULT_MODEL
+        if (
+            "model" not in values
+            and "model_name" not in values
+            and not WARNED_DEFAULT_MODEL
+        ):
+            warnings.warn(
+                "Groq is retiring the default model for ChatGroq, mixtral-8x7b-32768, "
+                "on March 20, 2025. Requests with the default model will start failing "
+                "on that date. Version 0.3.0 of langchain-groq will remove the "
+                "default. Please specify `model` explicitly, e.g., "
+                "`model='mistral-saba-24b'` or `model='llama-3.3-70b-versatile'`.",
+            )
+            WARNED_DEFAULT_MODEL = True
+        return values
+
     @model_validator(mode="before")
     @classmethod
     def build_extra(cls, values: Dict[str, Any]) -> Any:
@@ -21,6 +21,8 @@ from tests.unit_tests.fake.callbacks import (
     FakeCallbackHandlerWithChatStart,
 )
 
+MODEL_NAME = "llama-3.3-70b-versatile"
+
 
 #
 # Smoke test Runnable interface
@@ -28,7 +30,8 @@ from tests.unit_tests.fake.callbacks import (
 @pytest.mark.scheduled
 def test_invoke() -> None:
     """Test Chat wrapper."""
-    chat = ChatGroq(  # type: ignore[call-arg]
+    chat = ChatGroq(
+        model=MODEL_NAME,
         temperature=0.7,
         base_url=None,
         groq_proxy=None,
@@ -49,7 +52,7 @@ def test_invoke() -> None:
 @pytest.mark.scheduled
 async def test_ainvoke() -> None:
     """Test ainvoke tokens from ChatGroq."""
-    chat = ChatGroq(max_tokens=10)  # type: ignore[call-arg]
+    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)
 
     result = await chat.ainvoke("Welcome to the Groqetship!", config={"tags": ["foo"]})
     assert isinstance(result, BaseMessage)
@@ -59,7 +62,7 @@ async def test_ainvoke() -> None:
 @pytest.mark.scheduled
 def test_batch() -> None:
     """Test batch tokens from ChatGroq."""
-    chat = ChatGroq(max_tokens=10)  # type: ignore[call-arg]
+    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)
 
     result = chat.batch(["Hello!", "Welcome to the Groqetship!"])
     for token in result:
@@ -70,7 +73,7 @@ def test_batch() -> None:
 @pytest.mark.scheduled
 async def test_abatch() -> None:
     """Test abatch tokens from ChatGroq."""
-    chat = ChatGroq(max_tokens=10)  # type: ignore[call-arg]
+    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)
 
     result = await chat.abatch(["Hello!", "Welcome to the Groqetship!"])
     for token in result:
@@ -81,7 +84,7 @@ async def test_abatch() -> None:
 @pytest.mark.scheduled
 async def test_stream() -> None:
     """Test streaming tokens from Groq."""
-    chat = ChatGroq(max_tokens=10)  # type: ignore[call-arg]
+    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)
 
     for token in chat.stream("Welcome to the Groqetship!"):
         assert isinstance(token, BaseMessageChunk)
@@ -91,7 +94,7 @@ async def test_stream() -> None:
 @pytest.mark.scheduled
 async def test_astream() -> None:
     """Test streaming tokens from Groq."""
-    chat = ChatGroq(max_tokens=10)  # type: ignore[call-arg]
+    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)
 
     full: Optional[BaseMessageChunk] = None
     chunks_with_token_counts = 0
@@ -124,7 +127,7 @@ async def test_astream() -> None:
 def test_generate() -> None:
     """Test sync generate."""
     n = 1
-    chat = ChatGroq(max_tokens=10)  # type: ignore[call-arg]
+    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)
     message = HumanMessage(content="Hello", n=1)
     response = chat.generate([[message], [message]])
     assert isinstance(response, LLMResult)
@@ -143,7 +146,7 @@ def test_generate() -> None:
 async def test_agenerate() -> None:
     """Test async generation."""
     n = 1
-    chat = ChatGroq(max_tokens=10, n=1)  # type: ignore[call-arg]
+    chat = ChatGroq(model=MODEL_NAME, max_tokens=10, n=1)
     message = HumanMessage(content="Hello")
     response = await chat.agenerate([[message], [message]])
     assert isinstance(response, LLMResult)
@@ -165,7 +168,8 @@ async def test_agenerate() -> None:
 def test_invoke_streaming() -> None:
     """Test that streaming correctly invokes on_llm_new_token callback."""
     callback_handler = FakeCallbackHandler()
-    chat = ChatGroq(  # type: ignore[call-arg]
+    chat = ChatGroq(
+        model=MODEL_NAME,
         max_tokens=2,
         streaming=True,
         temperature=0,
@@ -181,7 +185,8 @@ def test_invoke_streaming() -> None:
 async def test_agenerate_streaming() -> None:
     """Test that streaming correctly invokes on_llm_new_token callback."""
     callback_handler = FakeCallbackHandlerWithChatStart()
-    chat = ChatGroq(  # type: ignore[call-arg]
+    chat = ChatGroq(
+        model=MODEL_NAME,
         max_tokens=10,
         streaming=True,
         temperature=0,
@@ -220,7 +225,8 @@ def test_streaming_generation_info() -> None:
             self.saved_things["generation"] = args[0]
 
     callback = _FakeCallback()
-    chat = ChatGroq(  # type: ignore[call-arg]
+    chat = ChatGroq(
+        model=MODEL_NAME,
         max_tokens=2,
         temperature=0,
         callbacks=[callback],
@@ -234,7 +240,7 @@ def test_streaming_generation_info() -> None:
 
 def test_system_message() -> None:
     """Test ChatGroq wrapper with system message."""
-    chat = ChatGroq(max_tokens=10)  # type: ignore[call-arg]
+    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)
     system_message = SystemMessage(content="You are to chat with the user.")
     human_message = HumanMessage(content="Hello")
     response = chat.invoke([system_message, human_message])
@@ -242,10 +248,9 @@ def test_system_message() -> None:
     assert isinstance(response.content, str)
 
 
-@pytest.mark.xfail(reason="Groq tool_choice doesn't currently force a tool call")
 def test_tool_choice() -> None:
     """Test that tool choice is respected."""
-    llm = ChatGroq()  # type: ignore[call-arg]
+    llm = ChatGroq(model=MODEL_NAME)
 
     class MyTool(BaseModel):
         name: str
@@ -273,10 +278,9 @@ def test_tool_choice() -> None:
     assert tool_call["args"] == {"name": "Erick", "age": 27}
 
 
-@pytest.mark.xfail(reason="Groq tool_choice doesn't currently force a tool call")
 def test_tool_choice_bool() -> None:
     """Test that tool choice is respected just passing in True."""
-    llm = ChatGroq()  # type: ignore[call-arg]
+    llm = ChatGroq(model=MODEL_NAME)
 
     class MyTool(BaseModel):
         name: str
@@ -301,7 +305,7 @@ def test_tool_choice_bool() -> None:
 @pytest.mark.xfail(reason="Groq tool_choice doesn't currently force a tool call")
 def test_streaming_tool_call() -> None:
     """Test that tool choice is respected."""
-    llm = ChatGroq()  # type: ignore[call-arg]
+    llm = ChatGroq(model=MODEL_NAME)
 
     class MyTool(BaseModel):
         name: str
@@ -339,7 +343,7 @@ def test_streaming_tool_call() -> None:
 @pytest.mark.xfail(reason="Groq tool_choice doesn't currently force a tool call")
 async def test_astreaming_tool_call() -> None:
     """Test that tool choice is respected."""
-    llm = ChatGroq()  # type: ignore[call-arg]
+    llm = ChatGroq(model=MODEL_NAME)
 
     class MyTool(BaseModel):
         name: str
@@ -384,7 +388,7 @@ def test_json_mode_structured_output() -> None:
         setup: str = Field(description="question to set up a joke")
         punchline: str = Field(description="answer to resolve the joke")
 
-    chat = ChatGroq().with_structured_output(Joke, method="json_mode")  # type: ignore[call-arg]
+    chat = ChatGroq(model=MODEL_NAME).with_structured_output(Joke, method="json_mode")
     result = chat.invoke(
         "Tell me a joke about cats, respond in JSON with `setup` and `punchline` keys"
     )
@@ -16,7 +16,7 @@
       }),
       'max_retries': 2,
       'max_tokens': 100,
-      'model_name': 'mixtral-8x7b-32768',
+      'model_name': 'llama-3.1-8b-instant',
       'n': 1,
       'request_timeout': 60.0,
       'stop': list([
@@ -2,6 +2,7 @@
 
 import json
 import os
+import warnings
 from typing import Any
 from unittest.mock import AsyncMock, MagicMock, patch
 
@@ -156,7 +157,7 @@ def mock_completion() -> dict:
 
 
 def test_groq_invoke(mock_completion: dict) -> None:
-    llm = ChatGroq()  # type: ignore[call-arg]
+    llm = ChatGroq(model="foo")
     mock_client = MagicMock()
     completed = False
 
@@ -178,7 +179,7 @@ def test_groq_invoke(mock_completion: dict) -> None:
 
 
 async def test_groq_ainvoke(mock_completion: dict) -> None:
-    llm = ChatGroq()  # type: ignore[call-arg]
+    llm = ChatGroq(model="foo")
     mock_client = AsyncMock()
     completed = False
 
@@ -203,7 +204,7 @@ def test_chat_groq_extra_kwargs() -> None:
     """Test extra kwargs to chat groq."""
     # Check that foo is saved in extra_kwargs.
     with pytest.warns(UserWarning) as record:
-        llm = ChatGroq(foo=3, max_tokens=10)  # type: ignore[call-arg]
+        llm = ChatGroq(model="foo", foo=3, max_tokens=10)  # type: ignore[call-arg]
     assert llm.max_tokens == 10
     assert llm.model_kwargs == {"foo": 3}
     assert len(record) == 1
@@ -212,7 +213,7 @@ def test_chat_groq_extra_kwargs() -> None:
 
     # Test that if extra_kwargs are provided, they are added to it.
     with pytest.warns(UserWarning) as record:
-        llm = ChatGroq(foo=3, model_kwargs={"bar": 2})  # type: ignore[call-arg]
+        llm = ChatGroq(model="foo", foo=3, model_kwargs={"bar": 2})  # type: ignore[call-arg]
     assert llm.model_kwargs == {"foo": 3, "bar": 2}
     assert len(record) == 1
     assert type(record[0].message) is UserWarning
@@ -220,21 +221,22 @@ def test_chat_groq_extra_kwargs() -> None:
 
     # Test that if provided twice it errors
     with pytest.raises(ValueError):
-        ChatGroq(foo=3, model_kwargs={"foo": 2})  # type: ignore[call-arg]
+        ChatGroq(model="foo", foo=3, model_kwargs={"foo": 2})  # type: ignore[call-arg]
 
     # Test that if explicit param is specified in kwargs it errors
     with pytest.raises(ValueError):
-        ChatGroq(model_kwargs={"temperature": 0.2})  # type: ignore[call-arg]
+        ChatGroq(model="foo", model_kwargs={"temperature": 0.2})
 
     # Test that "model" cannot be specified in kwargs
     with pytest.raises(ValueError):
-        ChatGroq(model_kwargs={"model": "test-model"})  # type: ignore[call-arg]
+        ChatGroq(model="foo", model_kwargs={"model": "test-model"})
 
 
 def test_chat_groq_invalid_streaming_params() -> None:
     """Test that an error is raised if streaming is invoked with n>1."""
     with pytest.raises(ValueError):
-        ChatGroq(  # type: ignore[call-arg]
+        ChatGroq(
+            model="foo",
             max_tokens=10,
             streaming=True,
             temperature=0,
@@ -246,7 +248,7 @@ def test_chat_groq_secret() -> None:
     """Test that secret is not printed"""
     secret = "secretKey"
     not_secret = "safe"
-    llm = ChatGroq(api_key=secret, model_kwargs={"not_secret": not_secret})  # type: ignore[call-arg, arg-type]
+    llm = ChatGroq(model="foo", api_key=secret, model_kwargs={"not_secret": not_secret})  # type: ignore[call-arg, arg-type]
     stringified = str(llm)
     assert not_secret in stringified
     assert secret not in stringified
@@ -257,7 +259,7 @@ def test_groq_serialization() -> None:
     """Test that ChatGroq can be successfully serialized and deserialized"""
     api_key1 = "top secret"
     api_key2 = "topest secret"
-    llm = ChatGroq(api_key=api_key1, temperature=0.5)  # type: ignore[call-arg, arg-type]
+    llm = ChatGroq(model="foo", api_key=api_key1, temperature=0.5)  # type: ignore[call-arg, arg-type]
     dump = lc_load.dumps(llm)
     llm2 = lc_load.loads(
         dump,
@@ -278,3 +280,23 @@ def test_groq_serialization() -> None:
 
     # Ensure a None was preserved
     assert llm.groq_api_base == llm2.groq_api_base
+
+
+def test_groq_warns_default_model() -> None:
+    """Test that a warning is raised if a default model is used."""
+
+    # Delete this test in 0.3 release, when the default model is removed.
+
+    # Test no warning if model is specified
+    with warnings.catch_warnings():
+        warnings.simplefilter("error")
+        ChatGroq(model="foo")
+
+    # Test warns if default model is used
+    with pytest.warns(match="default model"):
+        ChatGroq()
+
+    # Test only warns once
+    with warnings.catch_warnings():
+        warnings.simplefilter("error")
+        ChatGroq()
@@ -14,3 +14,7 @@ class TestGroqStandard(ChatModelUnitTests):
     @property
     def chat_model_class(self) -> Type[BaseChatModel]:
         return ChatGroq
+
+    @property
+    def chat_model_params(self) -> dict:
+        return {"model": "llama-3.1-8b-instant"}