fireworks[minor]: remove default model and temperature (#30965)

`mixtral-8x7b-instruct` was recently retired from Fireworks Serverless.

Here we remove the default model altogether, so that the model must be
explicitly specified on init:
```python
ChatFireworks(model="accounts/fireworks/models/llama-v3p1-70b-instruct")  # for example
```

We also set a null default for `temperature`, which previously defaulted
to 0.0. This parameter will no longer be included in request payloads
unless it is explicitly provided.
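
To preserve the previous behavior, pass the old default explicitly on init; a minimal sketch (the model name here is just an example):
```python
from langchain_fireworks import ChatFireworks

# temperature is now only sent when set explicitly;
# temperature=0.0 matches the previous default.
llm = ChatFireworks(
    model="accounts/fireworks/models/llama-v3p1-70b-instruct",  # example model
    temperature=0.0,
)
```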
ccurme 2025-04-22 15:58:58 -04:00 committed by GitHub
parent 4be55f7c89
commit eedda164c6
6 changed files with 27 additions and 89 deletions


@@ -90,7 +90,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": null,
    "id": "d285fd7f",
    "metadata": {},
    "outputs": [],
@@ -99,7 +99,7 @@
     "\n",
     "# Initialize a Fireworks model\n",
     "llm = Fireworks(\n",
-    "    model=\"accounts/fireworks/models/mixtral-8x7b-instruct\",\n",
+    "    model=\"accounts/fireworks/models/llama-v3p1-8b-instruct\",\n",
     "    base_url=\"https://api.fireworks.ai/inference/v1/completions\",\n",
     ")"
    ]
@@ -176,7 +176,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": null,
    "id": "b801c20d",
    "metadata": {},
    "outputs": [
@@ -192,7 +192,7 @@
    "source": [
     "# Setting additional parameters: temperature, max_tokens, top_p\n",
     "llm = Fireworks(\n",
-    "    model=\"accounts/fireworks/models/mixtral-8x7b-instruct\",\n",
+    "    model=\"accounts/fireworks/models/llama-v3p1-8b-instruct\",\n",
     "    temperature=0.7,\n",
     "    max_tokens=15,\n",
     "    top_p=1.0,\n",
@@ -218,7 +218,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": null,
    "id": "fd2c6bc1",
    "metadata": {},
    "outputs": [
@@ -235,7 +235,7 @@
     "from langchain_fireworks import Fireworks\n",
     "\n",
     "llm = Fireworks(\n",
-    "    model=\"accounts/fireworks/models/mixtral-8x7b-instruct\",\n",
+    "    model=\"accounts/fireworks/models/llama-v3p1-8b-instruct\",\n",
     "    temperature=0.7,\n",
     "    max_tokens=15,\n",
     "    top_p=1.0,\n",


@@ -39,7 +39,7 @@ import os
 # Initialize a Fireworks model
 llm = Fireworks(
-    model="accounts/fireworks/models/mixtral-8x7b-instruct",
+    model="accounts/fireworks/models/llama-v3p1-8b-instruct",
     base_url="https://api.fireworks.ai/inference/v1/completions",
 )
 ```


@@ -279,7 +279,7 @@ class ChatFireworks(BaseChatModel):
             from langchain_fireworks.chat_models import ChatFireworks
             fireworks = ChatFireworks(
-                model_name="accounts/fireworks/models/mixtral-8x7b-instruct")
+                model_name="accounts/fireworks/models/llama-v3p1-8b-instruct")
     """
@property
@@ -306,11 +306,9 @@
     client: Any = Field(default=None, exclude=True)  #: :meta private:
     async_client: Any = Field(default=None, exclude=True)  #: :meta private:
-    model_name: str = Field(
-        default="accounts/fireworks/models/mixtral-8x7b-instruct", alias="model"
-    )
+    model_name: str = Field(alias="model")
     """Model name to use."""
-    temperature: float = 0.0
+    temperature: Optional[float] = None
     """What sampling temperature to use."""
     stop: Optional[Union[str, list[str]]] = Field(default=None, alias="stop_sequences")
     """Default stop sequences."""
@@ -397,10 +395,11 @@
             "model": self.model_name,
             "stream": self.streaming,
             "n": self.n,
-            "temperature": self.temperature,
             "stop": self.stop,
             **self.model_kwargs,
         }
+        if self.temperature is not None:
+            params["temperature"] = self.temperature
         if self.max_tokens is not None:
             params["max_tokens"] = self.max_tokens
         return params


@@ -13,54 +13,13 @@ from typing_extensions import TypedDict
 from langchain_fireworks import ChatFireworks
 
-def test_chat_fireworks_call() -> None:
-    """Test valid call to fireworks."""
-    llm = ChatFireworks(  # type: ignore[call-arg]
-        model="accounts/fireworks/models/llama-v3p1-70b-instruct", temperature=0
-    )
-    resp = llm.invoke("Hello!")
-    assert isinstance(resp, AIMessage)
-    assert len(resp.content) > 0
-
-def test_tool_choice() -> None:
-    """Test that tool choice is respected."""
-    llm = ChatFireworks(  # type: ignore[call-arg]
-        model="accounts/fireworks/models/llama-v3p1-70b-instruct", temperature=0
-    )
-
-    class MyTool(BaseModel):
-        name: str
-        age: int
-
-    with_tool = llm.bind_tools([MyTool], tool_choice="MyTool")
-
-    resp = with_tool.invoke("Who was the 27 year old named Erick?")
-    assert isinstance(resp, AIMessage)
-    assert resp.content == ""  # should just be tool call
-    tool_calls = resp.additional_kwargs["tool_calls"]
-    assert len(tool_calls) == 1
-    tool_call = tool_calls[0]
-    assert tool_call["function"]["name"] == "MyTool"
-    assert json.loads(tool_call["function"]["arguments"]) == {
-        "age": 27,
-        "name": "Erick",
-    }
-    assert tool_call["type"] == "function"
-    assert isinstance(resp.tool_calls, list)
-    assert len(resp.tool_calls) == 1
-    tool_call = resp.tool_calls[0]
-    assert tool_call["name"] == "MyTool"
-    assert tool_call["args"] == {"age": 27, "name": "Erick"}
+_MODEL = "accounts/fireworks/models/llama-v3p1-8b-instruct"
 
 def test_tool_choice_bool() -> None:
     """Test that tool choice is respected just passing in True."""
-    llm = ChatFireworks(  # type: ignore[call-arg]
+    llm = ChatFireworks(
         model="accounts/fireworks/models/llama-v3p1-70b-instruct", temperature=0
     )
@@ -84,17 +43,9 @@ def test_tool_choice_bool() -> None:
     assert tool_call["type"] == "function"
 
-def test_stream() -> None:
-    """Test streaming tokens from ChatFireworks."""
-    llm = ChatFireworks()  # type: ignore[call-arg]
-
-    for token in llm.stream("I'm Pickle Rick"):
-        assert isinstance(token.content, str)
 
 async def test_astream() -> None:
     """Test streaming tokens from ChatFireworks."""
-    llm = ChatFireworks()  # type: ignore[call-arg]
+    llm = ChatFireworks(model=_MODEL)
 
     full: Optional[BaseMessageChunk] = None
     chunks_with_token_counts = 0
@@ -125,18 +76,9 @@ async def test_astream() -> None:
     assert full.response_metadata["model_name"]
 
-async def test_abatch() -> None:
-    """Test abatch tokens from ChatFireworks."""
-    llm = ChatFireworks()  # type: ignore[call-arg]
-
-    result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
-    for token in result:
-        assert isinstance(token.content, str)
 
 async def test_abatch_tags() -> None:
     """Test batch tokens from ChatFireworks."""
-    llm = ChatFireworks()  # type: ignore[call-arg]
+    llm = ChatFireworks(model=_MODEL)
 
     result = await llm.abatch(
         ["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
@@ -145,18 +87,9 @@ async def test_abatch_tags() -> None:
     assert isinstance(token.content, str)
 
-def test_batch() -> None:
-    """Test batch tokens from ChatFireworks."""
-    llm = ChatFireworks()  # type: ignore[call-arg]
-
-    result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
-    for token in result:
-        assert isinstance(token.content, str)
 
 async def test_ainvoke() -> None:
     """Test invoke tokens from ChatFireworks."""
-    llm = ChatFireworks()  # type: ignore[call-arg]
+    llm = ChatFireworks(model=_MODEL)
 
     result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
     assert isinstance(result.content, str)
@@ -164,7 +97,7 @@ async def test_ainvoke() -> None:
 def test_invoke() -> None:
     """Test invoke tokens from ChatFireworks."""
-    llm = ChatFireworks()  # type: ignore[call-arg]
+    llm = ChatFireworks(model=_MODEL)
 
     result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
     assert isinstance(result.content, str)


@@ -17,11 +17,12 @@
       }),
       'max_retries': 2,
       'max_tokens': 100,
-      'model_name': 'accounts/fireworks/models/mixtral-8x7b-instruct',
+      'model_name': 'accounts/fireworks/models/llama-v3p1-70b-instruct',
       'n': 1,
       'request_timeout': 60.0,
       'stop': list([
       ]),
-      'temperature': 0.0,
     }),
     'lc': 1,
     'name': 'ChatFireworks',


@@ -15,7 +15,10 @@ class TestFireworksStandard(ChatModelUnitTests):
     @property
     def chat_model_params(self) -> dict:
-        return {"api_key": "test_api_key"}
+        return {
+            "model": "accounts/fireworks/models/llama-v3p1-70b-instruct",
+            "api_key": "test_api_key",
+        }
 
     @property
     def init_from_env_params(self) -> tuple[dict, dict, dict]:
@@ -24,7 +27,9 @@ class TestFireworksStandard(ChatModelUnitTests):
                 "FIREWORKS_API_KEY": "api_key",
                 "FIREWORKS_API_BASE": "https://base.com",
             },
-            {},
+            {
+                "model": "accounts/fireworks/models/llama-v3p1-70b-instruct",
+            },
             {
                 "fireworks_api_key": "api_key",
                 "fireworks_api_base": "https://base.com",