mirror of
https://github.com/hwchase17/langchain.git
synced 2025-04-28 11:55:21 +00:00
fireworks[minor]: remove default model and temperature (#30965)
`mixtral-8x-7b-instruct` was recently retired from Fireworks Serverless. Here we remove the default model altogether, so that the model must be explicitly specified on init: ```python ChatFireworks(model="accounts/fireworks/models/llama-v3p1-70b-instruct") # for example ``` We also set a null default for `temperature`, which previously defaulted to 0.0. This parameter will no longer be included in request payloads unless it is explicitly provided.
This commit is contained in:
parent
4be55f7c89
commit
eedda164c6
@ -90,7 +90,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": null,
|
||||
"id": "d285fd7f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -99,7 +99,7 @@
|
||||
"\n",
|
||||
"# Initialize a Fireworks model\n",
|
||||
"llm = Fireworks(\n",
|
||||
" model=\"accounts/fireworks/models/mixtral-8x7b-instruct\",\n",
|
||||
" model=\"accounts/fireworks/models/llama-v3p1-8b-instruct\",\n",
|
||||
" base_url=\"https://api.fireworks.ai/inference/v1/completions\",\n",
|
||||
")"
|
||||
]
|
||||
@ -176,7 +176,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": null,
|
||||
"id": "b801c20d",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -192,7 +192,7 @@
|
||||
"source": [
|
||||
"# Setting additional parameters: temperature, max_tokens, top_p\n",
|
||||
"llm = Fireworks(\n",
|
||||
" model=\"accounts/fireworks/models/mixtral-8x7b-instruct\",\n",
|
||||
" model=\"accounts/fireworks/models/llama-v3p1-8b-instruct\",\n",
|
||||
" temperature=0.7,\n",
|
||||
" max_tokens=15,\n",
|
||||
" top_p=1.0,\n",
|
||||
@ -218,7 +218,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": null,
|
||||
"id": "fd2c6bc1",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -235,7 +235,7 @@
|
||||
"from langchain_fireworks import Fireworks\n",
|
||||
"\n",
|
||||
"llm = Fireworks(\n",
|
||||
" model=\"accounts/fireworks/models/mixtral-8x7b-instruct\",\n",
|
||||
" model=\"accounts/fireworks/models/llama-v3p1-8b-instruct\",\n",
|
||||
" temperature=0.7,\n",
|
||||
" max_tokens=15,\n",
|
||||
" top_p=1.0,\n",
|
||||
|
@ -39,7 +39,7 @@ import os
|
||||
|
||||
# Initialize a Fireworks model
|
||||
llm = Fireworks(
|
||||
model="accounts/fireworks/models/mixtral-8x7b-instruct",
|
||||
model="accounts/fireworks/models/llama-v3p1-8b-instruct",
|
||||
base_url="https://api.fireworks.ai/inference/v1/completions",
|
||||
)
|
||||
```
|
||||
|
@ -279,7 +279,7 @@ class ChatFireworks(BaseChatModel):
|
||||
|
||||
from langchain_fireworks.chat_models import ChatFireworks
|
||||
fireworks = ChatFireworks(
|
||||
model_name="accounts/fireworks/models/mixtral-8x7b-instruct")
|
||||
model_name="accounts/fireworks/models/llama-v3p1-8b-instruct")
|
||||
"""
|
||||
|
||||
@property
|
||||
@ -306,11 +306,9 @@ class ChatFireworks(BaseChatModel):
|
||||
|
||||
client: Any = Field(default=None, exclude=True) #: :meta private:
|
||||
async_client: Any = Field(default=None, exclude=True) #: :meta private:
|
||||
model_name: str = Field(
|
||||
default="accounts/fireworks/models/mixtral-8x7b-instruct", alias="model"
|
||||
)
|
||||
model_name: str = Field(alias="model")
|
||||
"""Model name to use."""
|
||||
temperature: float = 0.0
|
||||
temperature: Optional[float] = None
|
||||
"""What sampling temperature to use."""
|
||||
stop: Optional[Union[str, list[str]]] = Field(default=None, alias="stop_sequences")
|
||||
"""Default stop sequences."""
|
||||
@ -397,10 +395,11 @@ class ChatFireworks(BaseChatModel):
|
||||
"model": self.model_name,
|
||||
"stream": self.streaming,
|
||||
"n": self.n,
|
||||
"temperature": self.temperature,
|
||||
"stop": self.stop,
|
||||
**self.model_kwargs,
|
||||
}
|
||||
if self.temperature is not None:
|
||||
params["temperature"] = self.temperature
|
||||
if self.max_tokens is not None:
|
||||
params["max_tokens"] = self.max_tokens
|
||||
return params
|
||||
|
@ -13,54 +13,13 @@ from typing_extensions import TypedDict
|
||||
|
||||
from langchain_fireworks import ChatFireworks
|
||||
|
||||
|
||||
def test_chat_fireworks_call() -> None:
|
||||
"""Test valid call to fireworks."""
|
||||
llm = ChatFireworks( # type: ignore[call-arg]
|
||||
model="accounts/fireworks/models/llama-v3p1-70b-instruct", temperature=0
|
||||
)
|
||||
|
||||
resp = llm.invoke("Hello!")
|
||||
assert isinstance(resp, AIMessage)
|
||||
|
||||
assert len(resp.content) > 0
|
||||
|
||||
|
||||
def test_tool_choice() -> None:
|
||||
"""Test that tool choice is respected."""
|
||||
llm = ChatFireworks( # type: ignore[call-arg]
|
||||
model="accounts/fireworks/models/llama-v3p1-70b-instruct", temperature=0
|
||||
)
|
||||
|
||||
class MyTool(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
with_tool = llm.bind_tools([MyTool], tool_choice="MyTool")
|
||||
|
||||
resp = with_tool.invoke("Who was the 27 year old named Erick?")
|
||||
assert isinstance(resp, AIMessage)
|
||||
assert resp.content == "" # should just be tool call
|
||||
tool_calls = resp.additional_kwargs["tool_calls"]
|
||||
assert len(tool_calls) == 1
|
||||
tool_call = tool_calls[0]
|
||||
assert tool_call["function"]["name"] == "MyTool"
|
||||
assert json.loads(tool_call["function"]["arguments"]) == {
|
||||
"age": 27,
|
||||
"name": "Erick",
|
||||
}
|
||||
assert tool_call["type"] == "function"
|
||||
assert isinstance(resp.tool_calls, list)
|
||||
assert len(resp.tool_calls) == 1
|
||||
tool_call = resp.tool_calls[0]
|
||||
assert tool_call["name"] == "MyTool"
|
||||
assert tool_call["args"] == {"age": 27, "name": "Erick"}
|
||||
_MODEL = "accounts/fireworks/models/llama-v3p1-8b-instruct"
|
||||
|
||||
|
||||
def test_tool_choice_bool() -> None:
|
||||
"""Test that tool choice is respected just passing in True."""
|
||||
|
||||
llm = ChatFireworks( # type: ignore[call-arg]
|
||||
llm = ChatFireworks(
|
||||
model="accounts/fireworks/models/llama-v3p1-70b-instruct", temperature=0
|
||||
)
|
||||
|
||||
@ -84,17 +43,9 @@ def test_tool_choice_bool() -> None:
|
||||
assert tool_call["type"] == "function"
|
||||
|
||||
|
||||
def test_stream() -> None:
|
||||
"""Test streaming tokens from ChatFireworks."""
|
||||
llm = ChatFireworks() # type: ignore[call-arg]
|
||||
|
||||
for token in llm.stream("I'm Pickle Rick"):
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
async def test_astream() -> None:
|
||||
"""Test streaming tokens from ChatFireworks."""
|
||||
llm = ChatFireworks() # type: ignore[call-arg]
|
||||
llm = ChatFireworks(model=_MODEL)
|
||||
|
||||
full: Optional[BaseMessageChunk] = None
|
||||
chunks_with_token_counts = 0
|
||||
@ -125,18 +76,9 @@ async def test_astream() -> None:
|
||||
assert full.response_metadata["model_name"]
|
||||
|
||||
|
||||
async def test_abatch() -> None:
|
||||
"""Test abatch tokens from ChatFireworks."""
|
||||
llm = ChatFireworks() # type: ignore[call-arg]
|
||||
|
||||
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||
for token in result:
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
async def test_abatch_tags() -> None:
|
||||
"""Test batch tokens from ChatFireworks."""
|
||||
llm = ChatFireworks() # type: ignore[call-arg]
|
||||
llm = ChatFireworks(model=_MODEL)
|
||||
|
||||
result = await llm.abatch(
|
||||
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
|
||||
@ -145,18 +87,9 @@ async def test_abatch_tags() -> None:
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
def test_batch() -> None:
|
||||
"""Test batch tokens from ChatFireworks."""
|
||||
llm = ChatFireworks() # type: ignore[call-arg]
|
||||
|
||||
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||
for token in result:
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
async def test_ainvoke() -> None:
|
||||
"""Test invoke tokens from ChatFireworks."""
|
||||
llm = ChatFireworks() # type: ignore[call-arg]
|
||||
llm = ChatFireworks(model=_MODEL)
|
||||
|
||||
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
|
||||
assert isinstance(result.content, str)
|
||||
@ -164,7 +97,7 @@ async def test_ainvoke() -> None:
|
||||
|
||||
def test_invoke() -> None:
|
||||
"""Test invoke tokens from ChatFireworks."""
|
||||
llm = ChatFireworks() # type: ignore[call-arg]
|
||||
llm = ChatFireworks(model=_MODEL)
|
||||
|
||||
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
|
||||
assert isinstance(result.content, str)
|
||||
|
@ -17,11 +17,12 @@
|
||||
}),
|
||||
'max_retries': 2,
|
||||
'max_tokens': 100,
|
||||
'model_name': 'accounts/fireworks/models/mixtral-8x7b-instruct',
|
||||
'model_name': 'accounts/fireworks/models/llama-v3p1-70b-instruct',
|
||||
'n': 1,
|
||||
'request_timeout': 60.0,
|
||||
'stop': list([
|
||||
]),
|
||||
'temperature': 0.0,
|
||||
}),
|
||||
'lc': 1,
|
||||
'name': 'ChatFireworks',
|
||||
|
@ -15,7 +15,10 @@ class TestFireworksStandard(ChatModelUnitTests):
|
||||
|
||||
@property
|
||||
def chat_model_params(self) -> dict:
|
||||
return {"api_key": "test_api_key"}
|
||||
return {
|
||||
"model": "accounts/fireworks/models/llama-v3p1-70b-instruct",
|
||||
"api_key": "test_api_key",
|
||||
}
|
||||
|
||||
@property
|
||||
def init_from_env_params(self) -> tuple[dict, dict, dict]:
|
||||
@ -24,7 +27,9 @@ class TestFireworksStandard(ChatModelUnitTests):
|
||||
"FIREWORKS_API_KEY": "api_key",
|
||||
"FIREWORKS_API_BASE": "https://base.com",
|
||||
},
|
||||
{},
|
||||
{
|
||||
"model": "accounts/fireworks/models/llama-v3p1-70b-instruct",
|
||||
},
|
||||
{
|
||||
"fireworks_api_key": "api_key",
|
||||
"fireworks_api_base": "https://base.com",
|
||||
|
Loading…
Reference in New Issue
Block a user