mirror of
https://github.com/hwchase17/langchain.git
synced 2025-04-28 20:05:58 +00:00
fireworks[minor]: remove default model and temperature (#30965)
`mixtral-8x-7b-instruct` was recently retired from Fireworks Serverless. Here we remove the default model altogether, so that the model must be explicitly specified on init: ```python ChatFireworks(model="accounts/fireworks/models/llama-v3p1-70b-instruct") # for example ``` We also set a null default for `temperature`, which previously defaulted to 0.0. This parameter will no longer be included in request payloads unless it is explicitly provided.
This commit is contained in:
parent
4be55f7c89
commit
eedda164c6
@ -90,7 +90,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": null,
|
||||||
"id": "d285fd7f",
|
"id": "d285fd7f",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -99,7 +99,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"# Initialize a Fireworks model\n",
|
"# Initialize a Fireworks model\n",
|
||||||
"llm = Fireworks(\n",
|
"llm = Fireworks(\n",
|
||||||
" model=\"accounts/fireworks/models/mixtral-8x7b-instruct\",\n",
|
" model=\"accounts/fireworks/models/llama-v3p1-8b-instruct\",\n",
|
||||||
" base_url=\"https://api.fireworks.ai/inference/v1/completions\",\n",
|
" base_url=\"https://api.fireworks.ai/inference/v1/completions\",\n",
|
||||||
")"
|
")"
|
||||||
]
|
]
|
||||||
@ -176,7 +176,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 6,
|
"execution_count": null,
|
||||||
"id": "b801c20d",
|
"id": "b801c20d",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -192,7 +192,7 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"# Setting additional parameters: temperature, max_tokens, top_p\n",
|
"# Setting additional parameters: temperature, max_tokens, top_p\n",
|
||||||
"llm = Fireworks(\n",
|
"llm = Fireworks(\n",
|
||||||
" model=\"accounts/fireworks/models/mixtral-8x7b-instruct\",\n",
|
" model=\"accounts/fireworks/models/llama-v3p1-8b-instruct\",\n",
|
||||||
" temperature=0.7,\n",
|
" temperature=0.7,\n",
|
||||||
" max_tokens=15,\n",
|
" max_tokens=15,\n",
|
||||||
" top_p=1.0,\n",
|
" top_p=1.0,\n",
|
||||||
@ -218,7 +218,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 7,
|
"execution_count": null,
|
||||||
"id": "fd2c6bc1",
|
"id": "fd2c6bc1",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -235,7 +235,7 @@
|
|||||||
"from langchain_fireworks import Fireworks\n",
|
"from langchain_fireworks import Fireworks\n",
|
||||||
"\n",
|
"\n",
|
||||||
"llm = Fireworks(\n",
|
"llm = Fireworks(\n",
|
||||||
" model=\"accounts/fireworks/models/mixtral-8x7b-instruct\",\n",
|
" model=\"accounts/fireworks/models/llama-v3p1-8b-instruct\",\n",
|
||||||
" temperature=0.7,\n",
|
" temperature=0.7,\n",
|
||||||
" max_tokens=15,\n",
|
" max_tokens=15,\n",
|
||||||
" top_p=1.0,\n",
|
" top_p=1.0,\n",
|
||||||
|
@ -39,7 +39,7 @@ import os
|
|||||||
|
|
||||||
# Initialize a Fireworks model
|
# Initialize a Fireworks model
|
||||||
llm = Fireworks(
|
llm = Fireworks(
|
||||||
model="accounts/fireworks/models/mixtral-8x7b-instruct",
|
model="accounts/fireworks/models/llama-v3p1-8b-instruct",
|
||||||
base_url="https://api.fireworks.ai/inference/v1/completions",
|
base_url="https://api.fireworks.ai/inference/v1/completions",
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
@ -279,7 +279,7 @@ class ChatFireworks(BaseChatModel):
|
|||||||
|
|
||||||
from langchain_fireworks.chat_models import ChatFireworks
|
from langchain_fireworks.chat_models import ChatFireworks
|
||||||
fireworks = ChatFireworks(
|
fireworks = ChatFireworks(
|
||||||
model_name="accounts/fireworks/models/mixtral-8x7b-instruct")
|
model_name="accounts/fireworks/models/llama-v3p1-8b-instruct")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -306,11 +306,9 @@ class ChatFireworks(BaseChatModel):
|
|||||||
|
|
||||||
client: Any = Field(default=None, exclude=True) #: :meta private:
|
client: Any = Field(default=None, exclude=True) #: :meta private:
|
||||||
async_client: Any = Field(default=None, exclude=True) #: :meta private:
|
async_client: Any = Field(default=None, exclude=True) #: :meta private:
|
||||||
model_name: str = Field(
|
model_name: str = Field(alias="model")
|
||||||
default="accounts/fireworks/models/mixtral-8x7b-instruct", alias="model"
|
|
||||||
)
|
|
||||||
"""Model name to use."""
|
"""Model name to use."""
|
||||||
temperature: float = 0.0
|
temperature: Optional[float] = None
|
||||||
"""What sampling temperature to use."""
|
"""What sampling temperature to use."""
|
||||||
stop: Optional[Union[str, list[str]]] = Field(default=None, alias="stop_sequences")
|
stop: Optional[Union[str, list[str]]] = Field(default=None, alias="stop_sequences")
|
||||||
"""Default stop sequences."""
|
"""Default stop sequences."""
|
||||||
@ -397,10 +395,11 @@ class ChatFireworks(BaseChatModel):
|
|||||||
"model": self.model_name,
|
"model": self.model_name,
|
||||||
"stream": self.streaming,
|
"stream": self.streaming,
|
||||||
"n": self.n,
|
"n": self.n,
|
||||||
"temperature": self.temperature,
|
|
||||||
"stop": self.stop,
|
"stop": self.stop,
|
||||||
**self.model_kwargs,
|
**self.model_kwargs,
|
||||||
}
|
}
|
||||||
|
if self.temperature is not None:
|
||||||
|
params["temperature"] = self.temperature
|
||||||
if self.max_tokens is not None:
|
if self.max_tokens is not None:
|
||||||
params["max_tokens"] = self.max_tokens
|
params["max_tokens"] = self.max_tokens
|
||||||
return params
|
return params
|
||||||
|
@ -13,54 +13,13 @@ from typing_extensions import TypedDict
|
|||||||
|
|
||||||
from langchain_fireworks import ChatFireworks
|
from langchain_fireworks import ChatFireworks
|
||||||
|
|
||||||
|
_MODEL = "accounts/fireworks/models/llama-v3p1-8b-instruct"
|
||||||
def test_chat_fireworks_call() -> None:
|
|
||||||
"""Test valid call to fireworks."""
|
|
||||||
llm = ChatFireworks( # type: ignore[call-arg]
|
|
||||||
model="accounts/fireworks/models/llama-v3p1-70b-instruct", temperature=0
|
|
||||||
)
|
|
||||||
|
|
||||||
resp = llm.invoke("Hello!")
|
|
||||||
assert isinstance(resp, AIMessage)
|
|
||||||
|
|
||||||
assert len(resp.content) > 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_tool_choice() -> None:
|
|
||||||
"""Test that tool choice is respected."""
|
|
||||||
llm = ChatFireworks( # type: ignore[call-arg]
|
|
||||||
model="accounts/fireworks/models/llama-v3p1-70b-instruct", temperature=0
|
|
||||||
)
|
|
||||||
|
|
||||||
class MyTool(BaseModel):
|
|
||||||
name: str
|
|
||||||
age: int
|
|
||||||
|
|
||||||
with_tool = llm.bind_tools([MyTool], tool_choice="MyTool")
|
|
||||||
|
|
||||||
resp = with_tool.invoke("Who was the 27 year old named Erick?")
|
|
||||||
assert isinstance(resp, AIMessage)
|
|
||||||
assert resp.content == "" # should just be tool call
|
|
||||||
tool_calls = resp.additional_kwargs["tool_calls"]
|
|
||||||
assert len(tool_calls) == 1
|
|
||||||
tool_call = tool_calls[0]
|
|
||||||
assert tool_call["function"]["name"] == "MyTool"
|
|
||||||
assert json.loads(tool_call["function"]["arguments"]) == {
|
|
||||||
"age": 27,
|
|
||||||
"name": "Erick",
|
|
||||||
}
|
|
||||||
assert tool_call["type"] == "function"
|
|
||||||
assert isinstance(resp.tool_calls, list)
|
|
||||||
assert len(resp.tool_calls) == 1
|
|
||||||
tool_call = resp.tool_calls[0]
|
|
||||||
assert tool_call["name"] == "MyTool"
|
|
||||||
assert tool_call["args"] == {"age": 27, "name": "Erick"}
|
|
||||||
|
|
||||||
|
|
||||||
def test_tool_choice_bool() -> None:
|
def test_tool_choice_bool() -> None:
|
||||||
"""Test that tool choice is respected just passing in True."""
|
"""Test that tool choice is respected just passing in True."""
|
||||||
|
|
||||||
llm = ChatFireworks( # type: ignore[call-arg]
|
llm = ChatFireworks(
|
||||||
model="accounts/fireworks/models/llama-v3p1-70b-instruct", temperature=0
|
model="accounts/fireworks/models/llama-v3p1-70b-instruct", temperature=0
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -84,17 +43,9 @@ def test_tool_choice_bool() -> None:
|
|||||||
assert tool_call["type"] == "function"
|
assert tool_call["type"] == "function"
|
||||||
|
|
||||||
|
|
||||||
def test_stream() -> None:
|
|
||||||
"""Test streaming tokens from ChatFireworks."""
|
|
||||||
llm = ChatFireworks() # type: ignore[call-arg]
|
|
||||||
|
|
||||||
for token in llm.stream("I'm Pickle Rick"):
|
|
||||||
assert isinstance(token.content, str)
|
|
||||||
|
|
||||||
|
|
||||||
async def test_astream() -> None:
|
async def test_astream() -> None:
|
||||||
"""Test streaming tokens from ChatFireworks."""
|
"""Test streaming tokens from ChatFireworks."""
|
||||||
llm = ChatFireworks() # type: ignore[call-arg]
|
llm = ChatFireworks(model=_MODEL)
|
||||||
|
|
||||||
full: Optional[BaseMessageChunk] = None
|
full: Optional[BaseMessageChunk] = None
|
||||||
chunks_with_token_counts = 0
|
chunks_with_token_counts = 0
|
||||||
@ -125,18 +76,9 @@ async def test_astream() -> None:
|
|||||||
assert full.response_metadata["model_name"]
|
assert full.response_metadata["model_name"]
|
||||||
|
|
||||||
|
|
||||||
async def test_abatch() -> None:
|
|
||||||
"""Test abatch tokens from ChatFireworks."""
|
|
||||||
llm = ChatFireworks() # type: ignore[call-arg]
|
|
||||||
|
|
||||||
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
|
||||||
for token in result:
|
|
||||||
assert isinstance(token.content, str)
|
|
||||||
|
|
||||||
|
|
||||||
async def test_abatch_tags() -> None:
|
async def test_abatch_tags() -> None:
|
||||||
"""Test batch tokens from ChatFireworks."""
|
"""Test batch tokens from ChatFireworks."""
|
||||||
llm = ChatFireworks() # type: ignore[call-arg]
|
llm = ChatFireworks(model=_MODEL)
|
||||||
|
|
||||||
result = await llm.abatch(
|
result = await llm.abatch(
|
||||||
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
|
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
|
||||||
@ -145,18 +87,9 @@ async def test_abatch_tags() -> None:
|
|||||||
assert isinstance(token.content, str)
|
assert isinstance(token.content, str)
|
||||||
|
|
||||||
|
|
||||||
def test_batch() -> None:
|
|
||||||
"""Test batch tokens from ChatFireworks."""
|
|
||||||
llm = ChatFireworks() # type: ignore[call-arg]
|
|
||||||
|
|
||||||
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
|
||||||
for token in result:
|
|
||||||
assert isinstance(token.content, str)
|
|
||||||
|
|
||||||
|
|
||||||
async def test_ainvoke() -> None:
|
async def test_ainvoke() -> None:
|
||||||
"""Test invoke tokens from ChatFireworks."""
|
"""Test invoke tokens from ChatFireworks."""
|
||||||
llm = ChatFireworks() # type: ignore[call-arg]
|
llm = ChatFireworks(model=_MODEL)
|
||||||
|
|
||||||
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
|
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
|
||||||
assert isinstance(result.content, str)
|
assert isinstance(result.content, str)
|
||||||
@ -164,7 +97,7 @@ async def test_ainvoke() -> None:
|
|||||||
|
|
||||||
def test_invoke() -> None:
|
def test_invoke() -> None:
|
||||||
"""Test invoke tokens from ChatFireworks."""
|
"""Test invoke tokens from ChatFireworks."""
|
||||||
llm = ChatFireworks() # type: ignore[call-arg]
|
llm = ChatFireworks(model=_MODEL)
|
||||||
|
|
||||||
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
|
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
|
||||||
assert isinstance(result.content, str)
|
assert isinstance(result.content, str)
|
||||||
|
@ -17,11 +17,12 @@
|
|||||||
}),
|
}),
|
||||||
'max_retries': 2,
|
'max_retries': 2,
|
||||||
'max_tokens': 100,
|
'max_tokens': 100,
|
||||||
'model_name': 'accounts/fireworks/models/mixtral-8x7b-instruct',
|
'model_name': 'accounts/fireworks/models/llama-v3p1-70b-instruct',
|
||||||
'n': 1,
|
'n': 1,
|
||||||
'request_timeout': 60.0,
|
'request_timeout': 60.0,
|
||||||
'stop': list([
|
'stop': list([
|
||||||
]),
|
]),
|
||||||
|
'temperature': 0.0,
|
||||||
}),
|
}),
|
||||||
'lc': 1,
|
'lc': 1,
|
||||||
'name': 'ChatFireworks',
|
'name': 'ChatFireworks',
|
||||||
|
@ -15,7 +15,10 @@ class TestFireworksStandard(ChatModelUnitTests):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def chat_model_params(self) -> dict:
|
def chat_model_params(self) -> dict:
|
||||||
return {"api_key": "test_api_key"}
|
return {
|
||||||
|
"model": "accounts/fireworks/models/llama-v3p1-70b-instruct",
|
||||||
|
"api_key": "test_api_key",
|
||||||
|
}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def init_from_env_params(self) -> tuple[dict, dict, dict]:
|
def init_from_env_params(self) -> tuple[dict, dict, dict]:
|
||||||
@ -24,7 +27,9 @@ class TestFireworksStandard(ChatModelUnitTests):
|
|||||||
"FIREWORKS_API_KEY": "api_key",
|
"FIREWORKS_API_KEY": "api_key",
|
||||||
"FIREWORKS_API_BASE": "https://base.com",
|
"FIREWORKS_API_BASE": "https://base.com",
|
||||||
},
|
},
|
||||||
{},
|
{
|
||||||
|
"model": "accounts/fireworks/models/llama-v3p1-70b-instruct",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"fireworks_api_key": "api_key",
|
"fireworks_api_key": "api_key",
|
||||||
"fireworks_api_base": "https://base.com",
|
"fireworks_api_base": "https://base.com",
|
||||||
|
Loading…
Reference in New Issue
Block a user