diff --git a/docs/docs/integrations/llms/fireworks.ipynb b/docs/docs/integrations/llms/fireworks.ipynb
index 3dd7fc4483c..2a558b29306 100644
--- a/docs/docs/integrations/llms/fireworks.ipynb
+++ b/docs/docs/integrations/llms/fireworks.ipynb
@@ -90,7 +90,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": null,
    "id": "d285fd7f",
    "metadata": {},
    "outputs": [],
@@ -99,7 +99,7 @@
    "\n",
    "# Initialize a Fireworks model\n",
    "llm = Fireworks(\n",
-   "    model=\"accounts/fireworks/models/mixtral-8x7b-instruct\",\n",
+   "    model=\"accounts/fireworks/models/llama-v3p1-8b-instruct\",\n",
    "    base_url=\"https://api.fireworks.ai/inference/v1/completions\",\n",
    ")"
   ]
  },
@@ -176,7 +176,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": null,
    "id": "b801c20d",
    "metadata": {},
    "outputs": [
@@ -192,7 +192,7 @@
    "source": [
    "# Setting additional parameters: temperature, max_tokens, top_p\n",
    "llm = Fireworks(\n",
-   "    model=\"accounts/fireworks/models/mixtral-8x7b-instruct\",\n",
+   "    model=\"accounts/fireworks/models/llama-v3p1-8b-instruct\",\n",
    "    temperature=0.7,\n",
    "    max_tokens=15,\n",
    "    top_p=1.0,\n",
@@ -218,7 +218,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": null,
    "id": "fd2c6bc1",
    "metadata": {},
    "outputs": [
@@ -235,7 +235,7 @@
    "from langchain_fireworks import Fireworks\n",
    "\n",
    "llm = Fireworks(\n",
-   "    model=\"accounts/fireworks/models/mixtral-8x7b-instruct\",\n",
+   "    model=\"accounts/fireworks/models/llama-v3p1-8b-instruct\",\n",
    "    temperature=0.7,\n",
    "    max_tokens=15,\n",
    "    top_p=1.0,\n",
diff --git a/libs/partners/fireworks/README.md b/libs/partners/fireworks/README.md
index ccb55173a35..db634570637 100644
--- a/libs/partners/fireworks/README.md
+++ b/libs/partners/fireworks/README.md
@@ -39,7 +39,7 @@ import os
 
 # Initialize a Fireworks model
 llm = Fireworks(
-    model="accounts/fireworks/models/mixtral-8x7b-instruct",
+    model="accounts/fireworks/models/llama-v3p1-8b-instruct",
    base_url="https://api.fireworks.ai/inference/v1/completions",
 )
 ```
diff --git a/libs/partners/fireworks/langchain_fireworks/chat_models.py b/libs/partners/fireworks/langchain_fireworks/chat_models.py
index 70dd14684d4..e5c9cfee399 100644
--- a/libs/partners/fireworks/langchain_fireworks/chat_models.py
+++ b/libs/partners/fireworks/langchain_fireworks/chat_models.py
@@ -279,7 +279,7 @@ class ChatFireworks(BaseChatModel):
             from langchain_fireworks.chat_models import ChatFireworks
 
             fireworks = ChatFireworks(
-                model_name="accounts/fireworks/models/mixtral-8x7b-instruct")
+                model_name="accounts/fireworks/models/llama-v3p1-8b-instruct")
     """
 
     @property
@@ -306,11 +306,9 @@ class ChatFireworks(BaseChatModel):
 
     client: Any = Field(default=None, exclude=True)  #: :meta private:
     async_client: Any = Field(default=None, exclude=True)  #: :meta private:
-    model_name: str = Field(
-        default="accounts/fireworks/models/mixtral-8x7b-instruct", alias="model"
-    )
+    model_name: str = Field(alias="model")
     """Model name to use."""
-    temperature: float = 0.0
+    temperature: Optional[float] = None
     """What sampling temperature to use."""
     stop: Optional[Union[str, list[str]]] = Field(default=None, alias="stop_sequences")
     """Default stop sequences."""
@@ -397,10 +395,11 @@ class ChatFireworks(BaseChatModel):
             "model": self.model_name,
             "stream": self.streaming,
             "n": self.n,
-            "temperature": self.temperature,
             "stop": self.stop,
             **self.model_kwargs,
         }
+        if self.temperature is not None:
+            params["temperature"] = self.temperature
         if self.max_tokens is not None:
             params["max_tokens"] = self.max_tokens
         return params
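Note: the chat_models.py hunks above change two defaults at once: model_name loses its Mixtral fallback, so "model" becomes a required constructor argument, and temperature becomes Optional and is only sent to the API when explicitly set. A minimal sketch of the resulting caller-side behavior, not part of the patch; the model id and api_key below are placeholder values:

# Sketch: constructor behavior after these hunks (placeholder values).
from langchain_fireworks import ChatFireworks

# `model` is now required; ChatFireworks() with no model raises a
# pydantic validation error instead of silently defaulting to Mixtral.
llm = ChatFireworks(
    model="accounts/fireworks/models/llama-v3p1-8b-instruct",
    api_key="test_api_key",
)

# `temperature` defaults to None and is dropped from the request payload,
# so the Fireworks server-side default applies unless it is set explicitly.
warm = ChatFireworks(
    model="accounts/fireworks/models/llama-v3p1-8b-instruct",
    temperature=0.7,
    api_key="test_api_key",
)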
diff --git a/libs/partners/fireworks/tests/integration_tests/test_chat_models.py b/libs/partners/fireworks/tests/integration_tests/test_chat_models.py
index 0209aa9b15d..8de13e1f30f 100644
--- a/libs/partners/fireworks/tests/integration_tests/test_chat_models.py
+++ b/libs/partners/fireworks/tests/integration_tests/test_chat_models.py
@@ -13,54 +13,13 @@ from typing_extensions import TypedDict
 
 from langchain_fireworks import ChatFireworks
 
-
-def test_chat_fireworks_call() -> None:
-    """Test valid call to fireworks."""
-    llm = ChatFireworks(  # type: ignore[call-arg]
-        model="accounts/fireworks/models/llama-v3p1-70b-instruct", temperature=0
-    )
-
-    resp = llm.invoke("Hello!")
-    assert isinstance(resp, AIMessage)
-
-    assert len(resp.content) > 0
-
-
-def test_tool_choice() -> None:
-    """Test that tool choice is respected."""
-    llm = ChatFireworks(  # type: ignore[call-arg]
-        model="accounts/fireworks/models/llama-v3p1-70b-instruct", temperature=0
-    )
-
-    class MyTool(BaseModel):
-        name: str
-        age: int
-
-    with_tool = llm.bind_tools([MyTool], tool_choice="MyTool")
-
-    resp = with_tool.invoke("Who was the 27 year old named Erick?")
-    assert isinstance(resp, AIMessage)
-    assert resp.content == ""  # should just be tool call
-    tool_calls = resp.additional_kwargs["tool_calls"]
-    assert len(tool_calls) == 1
-    tool_call = tool_calls[0]
-    assert tool_call["function"]["name"] == "MyTool"
-    assert json.loads(tool_call["function"]["arguments"]) == {
-        "age": 27,
-        "name": "Erick",
-    }
-    assert tool_call["type"] == "function"
-    assert isinstance(resp.tool_calls, list)
-    assert len(resp.tool_calls) == 1
-    tool_call = resp.tool_calls[0]
-    assert tool_call["name"] == "MyTool"
-    assert tool_call["args"] == {"age": 27, "name": "Erick"}
+_MODEL = "accounts/fireworks/models/llama-v3p1-8b-instruct"
 
 
 def test_tool_choice_bool() -> None:
     """Test that tool choice is respected just passing in True."""
-    llm = ChatFireworks(  # type: ignore[call-arg]
+    llm = ChatFireworks(
         model="accounts/fireworks/models/llama-v3p1-70b-instruct", temperature=0
     )
@@ -84,17 +43,9 @@ def test_tool_choice_bool() -> None:
     assert tool_call["type"] == "function"
 
-
-def test_stream() -> None:
-    """Test streaming tokens from ChatFireworks."""
-    llm = ChatFireworks()  # type: ignore[call-arg]
-
-    for token in llm.stream("I'm Pickle Rick"):
-        assert isinstance(token.content, str)
-
-
 async def test_astream() -> None:
     """Test streaming tokens from ChatFireworks."""
-    llm = ChatFireworks()  # type: ignore[call-arg]
+    llm = ChatFireworks(model=_MODEL)
 
     full: Optional[BaseMessageChunk] = None
     chunks_with_token_counts = 0
@@ -125,18 +76,9 @@ async def test_astream() -> None:
         assert full.response_metadata["model_name"]
 
-
-async def test_abatch() -> None:
-    """Test abatch tokens from ChatFireworks."""
-    llm = ChatFireworks()  # type: ignore[call-arg]
-
-    result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
-    for token in result:
-        assert isinstance(token.content, str)
-
-
 async def test_abatch_tags() -> None:
     """Test batch tokens from ChatFireworks."""
-    llm = ChatFireworks()  # type: ignore[call-arg]
+    llm = ChatFireworks(model=_MODEL)
 
     result = await llm.abatch(
         ["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
     )
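Note: the removed test_tool_choice above exercised the same tool-calling path that the retained test_tool_choice_bool still covers; only the tool_choice value differs ("MyTool" versus True). A standalone sketch of that pattern, not part of the patch, with the schema, prompt, and model taken from the tests above; running it requires a real FIREWORKS_API_KEY:

# Sketch: the tool-choice pattern the kept integration test covers.
from pydantic import BaseModel

from langchain_fireworks import ChatFireworks


class MyTool(BaseModel):
    name: str
    age: int


llm = ChatFireworks(
    model="accounts/fireworks/models/llama-v3p1-70b-instruct", temperature=0
)
# tool_choice="MyTool" forces a call to that named tool; passing True
# (as in test_tool_choice_bool) forces a call to the single bound tool.
with_tool = llm.bind_tools([MyTool], tool_choice="MyTool")

resp = with_tool.invoke("Who was the 27 year old named Erick?")
# Expected parsed call: {"name": "MyTool", "args": {"age": 27, "name": "Erick"}}
print(resp.tool_calls)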
@@ -145,18 +87,9 @@ async def test_abatch_tags() -> None:
         assert isinstance(token.content, str)
 
-
-def test_batch() -> None:
-    """Test batch tokens from ChatFireworks."""
-    llm = ChatFireworks()  # type: ignore[call-arg]
-
-    result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
-    for token in result:
-        assert isinstance(token.content, str)
-
-
 async def test_ainvoke() -> None:
     """Test invoke tokens from ChatFireworks."""
-    llm = ChatFireworks()  # type: ignore[call-arg]
+    llm = ChatFireworks(model=_MODEL)
 
     result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
     assert isinstance(result.content, str)
@@ -164,7 +97,7 @@ async def test_ainvoke() -> None:
 
 def test_invoke() -> None:
     """Test invoke tokens from ChatFireworks."""
-    llm = ChatFireworks()  # type: ignore[call-arg]
+    llm = ChatFireworks(model=_MODEL)
 
     result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
     assert isinstance(result.content, str)
diff --git a/libs/partners/fireworks/tests/unit_tests/__snapshots__/test_standard.ambr b/libs/partners/fireworks/tests/unit_tests/__snapshots__/test_standard.ambr
index 4375bf55ff0..99ef0e8109a 100644
--- a/libs/partners/fireworks/tests/unit_tests/__snapshots__/test_standard.ambr
+++ b/libs/partners/fireworks/tests/unit_tests/__snapshots__/test_standard.ambr
@@ -17,11 +17,12 @@
       }),
       'max_retries': 2,
       'max_tokens': 100,
-      'model_name': 'accounts/fireworks/models/mixtral-8x7b-instruct',
+      'model_name': 'accounts/fireworks/models/llama-v3p1-70b-instruct',
       'n': 1,
       'request_timeout': 60.0,
       'stop': list([
       ]),
+      'temperature': 0.0,
     }),
     'lc': 1,
     'name': 'ChatFireworks',
diff --git a/libs/partners/fireworks/tests/unit_tests/test_standard.py b/libs/partners/fireworks/tests/unit_tests/test_standard.py
index 25cf0b6f3e4..3aee0335557 100644
--- a/libs/partners/fireworks/tests/unit_tests/test_standard.py
+++ b/libs/partners/fireworks/tests/unit_tests/test_standard.py
@@ -15,7 +15,10 @@ class TestFireworksStandard(ChatModelUnitTests):
 
     @property
     def chat_model_params(self) -> dict:
-        return {"api_key": "test_api_key"}
+        return {
+            "model": "accounts/fireworks/models/llama-v3p1-70b-instruct",
+            "api_key": "test_api_key",
+        }
 
     @property
     def init_from_env_params(self) -> tuple[dict, dict, dict]:
@@ -24,7 +27,9 @@ class TestFireworksStandard(ChatModelUnitTests):
                 "FIREWORKS_API_KEY": "api_key",
                 "FIREWORKS_API_BASE": "https://base.com",
             },
-            {},
+            {
+                "model": "accounts/fireworks/models/llama-v3p1-70b-instruct",
+            },
             {
                 "fireworks_api_key": "api_key",
                 "fireworks_api_base": "https://base.com",
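Note: per the standard-tests convention, the init_from_env_params tuple updated in the last hunk is (env vars to set, extra constructor kwargs, expected attribute values); with the default model gone, the middle element can no longer be empty. A sketch of what that test now effectively does, not part of the patch; env values are the placeholders from the hunk, and it assumes the key is stored as a pydantic SecretStr:

# Sketch: env-based construction still needs an explicit model.
import os

from langchain_fireworks import ChatFireworks

os.environ["FIREWORKS_API_KEY"] = "api_key"            # placeholder
os.environ["FIREWORKS_API_BASE"] = "https://base.com"  # placeholder

# Credentials and base URL come from the environment; the model does not.
llm = ChatFireworks(model="accounts/fireworks/models/llama-v3p1-70b-instruct")
assert llm.fireworks_api_base == "https://base.com"
assert llm.fireworks_api_key.get_secret_value() == "api_key"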