fireworks[patch]: Add Fireworks partner packages (#17694)

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Yufei (Benny) Chen
2024-02-23 12:45:47 -08:00
committed by GitHub
parent 11cf95e810
commit ee6a773456
31 changed files with 4740 additions and 1550 deletions

1
libs/partners/fireworks/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
__pycache__

View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2024 LangChain, Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -0,0 +1,59 @@
.PHONY: all format lint test tests integration_tests docker_tests help extended_tests
# Default target executed when no arguments are given to make.
all: help
# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/
test:
poetry run pytest $(TEST_FILE)
tests:
poetry run pytest $(TEST_FILE)
######################
# LINTING AND FORMATTING
######################
# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/fireworks --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain_fireworks
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test
lint lint_diff lint_package lint_tests:
poetry run ruff .
poetry run ruff format $(PYTHON_FILES) --diff
poetry run ruff --select I $(PYTHON_FILES)
mkdir $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
format format_diff:
poetry run ruff format $(PYTHON_FILES)
poetry run ruff --select I --fix $(PYTHON_FILES)
spell_check:
poetry run codespell --toml pyproject.toml
spell_fix:
poetry run codespell --toml pyproject.toml -w
check_imports: $(shell find langchain_fireworks -name '*.py')
poetry run python ./scripts/check_imports.py $^
######################
# HELP
######################
help:
@echo '----'
@echo 'check_imports - check imports'
@echo 'format - run code formatters'
@echo 'lint - run linters'
@echo 'test - run unit tests'
@echo 'tests - run unit tests'
@echo 'test TEST_FILE=<test_file> - run all tests in file'

View File

@@ -0,0 +1,15 @@
# LangChain-Fireworks
This is the partner package for tying Fireworks.ai and LangChain. Fireworks really strive to provide good support for LangChain use cases, so if you run into any issues please let us know. You can reach out to us [in our Discord channel](https://discord.com/channels/1137072072808472616/)
## Basic LangChain-Fireworks example
## Advanced
### Tool use: LangChain Agent + Fireworks function calling model
Please checkout how to teach Fireworks function calling model to use a [calculator here](https://github.com/fw-ai/cookbook/blob/main/examples/function_calling/fireworks_langchain_tool_usage.ipynb).
Fireworks focus on delivering the best experience for fast model inference as well as tool use. You can check out [our blog](https://fireworks.ai/blog/firefunction-v1-gpt-4-level-function-calling) for more details on how it fares compares to GPT-4, the punchline is that it is on par with GPT-4 in terms just function calling use cases, but it is way faster and much cheaper.
### RAG: LangChain agent + Fireworks function calling model + MongoDB + Nomic AI embeddings
Please check out the [cookbook here](https://github.com/fw-ai/cookbook/blob/main/examples/rag/mongodb_agent.ipynb) for an end to end flow

View File

@@ -0,0 +1,11 @@
from langchain_fireworks.chat_models import ChatFireworks
from langchain_fireworks.embeddings import FireworksEmbeddings
from langchain_fireworks.llms import Fireworks
from langchain_fireworks.version import __version__
__all__ = [
"__version__",
"ChatFireworks",
"Fireworks",
"FireworksEmbeddings",
]

View File

@@ -0,0 +1,615 @@
"""Fireworks chat wrapper."""
from __future__ import annotations
import logging
import os
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Literal,
Mapping,
Optional,
Sequence,
Tuple,
Type,
TypedDict,
Union,
cast,
)
from fireworks.client import AsyncFireworks, Fireworks # type: ignore
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
BaseChatModel,
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
FunctionMessage,
FunctionMessageChunk,
HumanMessage,
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
ToolMessage,
ToolMessageChunk,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils import (
convert_to_secret_str,
get_from_dict_or_env,
get_pydantic_field_names,
)
from langchain_core.utils.function_calling import (
convert_to_openai_function,
convert_to_openai_tool,
)
from langchain_core.utils.utils import build_extra_kwargs
logger = logging.getLogger(__name__)
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
"""Convert a dictionary to a LangChain message.
Args:
_dict: The dictionary.
Returns:
The LangChain message.
"""
role = _dict.get("role")
if role == "user":
return HumanMessage(content=_dict.get("content", ""))
elif role == "assistant":
# Fix for azure
# Also Fireworks returns None for tool invocations
content = _dict.get("content", "") or ""
additional_kwargs: Dict = {}
if function_call := _dict.get("function_call"):
additional_kwargs["function_call"] = dict(function_call)
if tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = tool_calls
return AIMessage(content=content, additional_kwargs=additional_kwargs)
elif role == "system":
return SystemMessage(content=_dict.get("content", ""))
elif role == "function":
return FunctionMessage(
content=_dict.get("content", ""), name=_dict.get("name", "")
)
elif role == "tool":
additional_kwargs = {}
if "name" in _dict:
additional_kwargs["name"] = _dict["name"]
return ToolMessage(
content=_dict.get("content", ""),
tool_call_id=_dict.get("tool_call_id", ""),
additional_kwargs=additional_kwargs,
)
else:
return ChatMessage(content=_dict.get("content", ""), role=role or "")
def _convert_message_to_dict(message: BaseMessage) -> dict:
"""Convert a LangChain message to a dictionary.
Args:
message: The LangChain message.
Returns:
The dictionary.
"""
message_dict: Dict[str, Any]
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
if "function_call" in message.additional_kwargs:
message_dict["function_call"] = message.additional_kwargs["function_call"]
# If function call only, content is None not empty string
if message_dict["content"] == "":
message_dict["content"] = None
if "tool_calls" in message.additional_kwargs:
message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
# If tool calls only, content is None not empty string
if message_dict["content"] == "":
message_dict["content"] = None
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, FunctionMessage):
message_dict = {
"role": "function",
"content": message.content,
"name": message.name,
}
elif isinstance(message, ToolMessage):
message_dict = {
"role": "tool",
"content": message.content,
"tool_call_id": message.tool_call_id,
}
else:
raise TypeError(f"Got unknown type {message}")
if "name" in message.additional_kwargs:
message_dict["name"] = message.additional_kwargs["name"]
return message_dict
def _convert_delta_to_message_chunk(
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
role = cast(str, _dict.get("role"))
content = cast(str, _dict.get("content") or "")
additional_kwargs: Dict = {}
if _dict.get("function_call"):
function_call = dict(_dict["function_call"])
if "name" in function_call and function_call["name"] is None:
function_call["name"] = ""
additional_kwargs["function_call"] = function_call
if _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = _dict["tool_calls"]
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
elif role == "function" or default_class == FunctionMessageChunk:
return FunctionMessageChunk(content=content, name=_dict["name"])
elif role == "tool" or default_class == ToolMessageChunk:
return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role)
else:
return default_class(content=content) # type: ignore
class _FunctionCall(TypedDict):
name: str
# This is basically a copy and replace for ChatOpenAI, except
# - I needed to gut out tiktoken and some of the token estimation logic
# (not sure how important it is)
# - Environment variable is different
# we should refactor into some OpenAI-like class in the future
class ChatFireworks(BaseChatModel):
"""`Fireworks` Chat large language models API.
To use, you should have the
environment variable ``FIREWORKS_API_KEY`` set with your API key.
Any parameters that are valid to be passed to the fireworks.create call
can be passed in, even if not explicitly saved on this class.
Example:
.. code-block:: python
from langchain_fireworks.chat_models import ChatFireworks
fireworks = ChatFireworks(
model_name="accounts/fireworks/models/mixtral-8x7b-instruct")
"""
@property
def lc_secrets(self) -> Dict[str, str]:
return {"fireworks_api_key": "FIREWORKS_API_KEY"}
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "fireworks"]
@property
def lc_attributes(self) -> Dict[str, Any]:
attributes: Dict[str, Any] = {}
if self.fireworks_api_base:
attributes["fireworks_api_base"] = self.fireworks_api_base
return attributes
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this model can be serialized by Langchain."""
return True
client: Any = Field(default=None, exclude=True) #: :meta private:
async_client: Any = Field(default=None, exclude=True) #: :meta private:
model_name: str = Field(
default="accounts/fireworks/models/mixtral-8x7b-instruct", alias="model"
)
"""Model name to use."""
temperature: float = 0.0
"""What sampling temperature to use."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
fireworks_api_key: SecretStr = Field(default=None, alias="api_key")
"""Automatically inferred from env var `FIREWORKS_API_KEY` if not provided."""
fireworks_api_base: Optional[str] = Field(default=None, alias="base_url")
"""Base URL path for API requests, leave blank if not using a proxy or service
emulator."""
request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
default=None, alias="timeout"
)
"""Timeout for requests to Fireworks completion API. Can be float, httpx.Timeout or
None."""
streaming: bool = False
"""Whether to stream the results or not."""
n: int = 1
"""Number of chat completions to generate for each prompt."""
max_tokens: Optional[int] = None
"""Maximum number of tokens to generate."""
class Config:
"""Configuration for this pydantic object."""
allow_population_by_field_name = True
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})
values["model_kwargs"] = build_extra_kwargs(
extra, values, all_required_field_names
)
return values
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
if values["n"] < 1:
raise ValueError("n must be at least 1.")
if values["n"] > 1 and values["streaming"]:
raise ValueError("n must be 1 when streaming.")
values["fireworks_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "fireworks_api_key", "FIREWORKS_API_KEY")
)
values["fireworks_api_base"] = values["fireworks_api_base"] or os.getenv(
"FIREWORKS_API_BASE"
)
client_params = {
"api_key": (
values["fireworks_api_key"].get_secret_value()
if values["fireworks_api_key"]
else None
),
"base_url": values["fireworks_api_base"],
"timeout": values["request_timeout"],
}
if not values.get("client"):
values["client"] = Fireworks(**client_params).chat.completions
if not values.get("async_client"):
values["async_client"] = AsyncFireworks(**client_params).chat.completions
return values
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling Fireworks API."""
params = {
"model": self.model_name,
"stream": self.streaming,
"n": self.n,
"temperature": self.temperature,
**self.model_kwargs,
}
if self.max_tokens is not None:
params["max_tokens"] = self.max_tokens
return params
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
overall_token_usage: dict = {}
system_fingerprint = None
for output in llm_outputs:
if output is None:
# Happens in streaming
continue
token_usage = output["token_usage"]
if token_usage is not None:
for k, v in token_usage.items():
if k in overall_token_usage:
overall_token_usage[k] += v
else:
overall_token_usage[k] = v
if system_fingerprint is None:
system_fingerprint = output.get("system_fingerprint")
combined = {"token_usage": overall_token_usage, "model_name": self.model_name}
if system_fingerprint:
combined["system_fingerprint"] = system_fingerprint
return combined
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
default_chunk_class = AIMessageChunk
for chunk in self.client.create(messages=message_dicts, **params):
if not isinstance(chunk, dict):
chunk = chunk.dict()
if len(chunk["choices"]) == 0:
continue
choice = chunk["choices"][0]
chunk = _convert_delta_to_message_chunk(
choice["delta"], default_chunk_class
)
generation_info = {}
if finish_reason := choice.get("finish_reason"):
generation_info["finish_reason"] = finish_reason
logprobs = choice.get("logprobs")
if logprobs:
generation_info["logprobs"] = logprobs
default_chunk_class = chunk.__class__
chunk = ChatGenerationChunk(
message=chunk, generation_info=generation_info or None
)
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs)
yield chunk
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
should_stream = stream if stream is not None else self.streaming
if should_stream:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {
**params,
**({"stream": stream} if stream is not None else {}),
**kwargs,
}
response = self.client.create(messages=message_dicts, **params)
return self._create_chat_result(response)
def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]]
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
params = self._default_params
if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
params["stop"] = stop
message_dicts = [_convert_message_to_dict(m) for m in messages]
return message_dicts, params
def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
generations = []
if not isinstance(response, dict):
response = response.dict()
for res in response["choices"]:
message = _convert_dict_to_message(res["message"])
generation_info = dict(finish_reason=res.get("finish_reason"))
if "logprobs" in res:
generation_info["logprobs"] = res["logprobs"]
gen = ChatGeneration(
message=message,
generation_info=generation_info,
)
generations.append(gen)
token_usage = response.get("usage", {})
llm_output = {
"token_usage": token_usage,
"model_name": self.model_name,
"system_fingerprint": response.get("system_fingerprint", ""),
}
return ChatResult(generations=generations, llm_output=llm_output)
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
default_chunk_class = AIMessageChunk
async for chunk in await self.async_client.create(
messages=message_dicts, **params
):
if not isinstance(chunk, dict):
chunk = chunk.dict()
if len(chunk["choices"]) == 0:
continue
choice = chunk["choices"][0]
chunk = _convert_delta_to_message_chunk(
choice["delta"], default_chunk_class
)
generation_info = {}
if finish_reason := choice.get("finish_reason"):
generation_info["finish_reason"] = finish_reason
logprobs = choice.get("logprobs")
if logprobs:
generation_info["logprobs"] = logprobs
default_chunk_class = chunk.__class__
chunk = ChatGenerationChunk(
message=chunk, generation_info=generation_info or None
)
if run_manager:
await run_manager.on_llm_new_token(
token=chunk.text, chunk=chunk, logprobs=logprobs
)
yield chunk
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
should_stream = stream if stream is not None else self.streaming
if should_stream:
stream_iter = self._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {
**params,
**({"stream": stream} if stream is not None else {}),
**kwargs,
}
response = await self.async_client.create(messages=message_dicts, **params)
return self._create_chat_result(response)
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return {"model_name": self.model_name, **self._default_params}
def _get_invocation_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
) -> Dict[str, Any]:
"""Get the parameters used to invoke the model."""
return {
"model": self.model_name,
**super()._get_invocation_params(stop=stop),
**self._default_params,
**kwargs,
}
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "fireworks-chat"
def bind_functions(
self,
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
function_call: Optional[
Union[_FunctionCall, str, Literal["auto", "none"]]
] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind functions (and other objects) to this chat model.
Assumes model is compatible with Fireworks function-calling API.
NOTE: Using bind_tools is recommended instead, as the `functions` and
`function_call` request parameters are officially marked as deprecated by
Fireworks.
Args:
functions: A list of function definitions to bind to this chat model.
Can be a dictionary, pydantic model, or callable. Pydantic
models and callables will be automatically converted to
their schema dictionary representation.
function_call: Which function to require the model to call.
Must be the name of the single provided function or
"auto" to automatically determine which function to call
(if any).
**kwargs: Any additional parameters to pass to the
:class:`~langchain.runnable.Runnable` constructor.
"""
formatted_functions = [convert_to_openai_function(fn) for fn in functions]
if function_call is not None:
function_call = (
{"name": function_call}
if isinstance(function_call, str)
and function_call not in ("auto", "none")
else function_call
)
if isinstance(function_call, dict) and len(formatted_functions) != 1:
raise ValueError(
"When specifying `function_call`, you must provide exactly one "
"function."
)
if (
isinstance(function_call, dict)
and formatted_functions[0]["name"] != function_call["name"]
):
raise ValueError(
f"Function call {function_call} was specified, but the only "
f"provided function was {formatted_functions[0]['name']}."
)
kwargs = {**kwargs, "function_call": function_call}
return super().bind(
functions=formatted_functions,
**kwargs,
)
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
*,
tool_choice: Optional[Union[dict, str, Literal["auto", "none"]]] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tool-like objects to this chat model.
Assumes model is compatible with Fireworks tool-calling API.
Args:
tools: A list of tool definitions to bind to this chat model.
Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
models, callables, and BaseTools will be automatically converted to
their schema dictionary representation.
tool_choice: Which tool to require the model to call.
Must be the name of the single provided function or
"auto" to automatically determine which function to call
(if any), or a dict of the form:
{"type": "function", "function": {"name": <<tool_name>>}}.
**kwargs: Any additional parameters to pass to the
:class:`~langchain.runnable.Runnable` constructor.
"""
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
if tool_choice is not None:
if isinstance(tool_choice, str) and (tool_choice not in ("auto", "none")):
tool_choice = {"type": "function", "function": {"name": tool_choice}}
if isinstance(tool_choice, dict) and (len(formatted_tools) != 1):
raise ValueError(
"When specifying `tool_choice`, you must provide exactly one "
f"tool. Received {len(formatted_tools)} tools."
)
if isinstance(tool_choice, dict) and (
formatted_tools[0]["function"]["name"]
!= tool_choice["function"]["name"]
):
raise ValueError(
f"Tool choice {tool_choice} was specified, but the only "
f"provided tool was {formatted_tools[0]['function']['name']}."
)
kwargs["tool_choice"] = tool_choice
return super().bind(tools=formatted_tools, **kwargs)

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,52 @@
import os
from typing import Any, Dict, List
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str
from openai import OpenAI # type: ignore
class FireworksEmbeddings(BaseModel, Embeddings):
"""FireworksEmbeddings embedding model.
Example:
.. code-block:: python
from langchain_fireworks import FireworksEmbeddings
model = FireworksEmbeddings(
model='nomic-ai/nomic-embed-text-v1.5'
)
"""
_client: OpenAI = Field(default=None)
fireworks_api_key: SecretStr = convert_to_secret_str("")
model: str = "nomic-ai/nomic-embed-text-v1.5"
@root_validator()
def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Validate environment variables."""
fireworks_api_key = convert_to_secret_str(
values.get("fireworks_api_key") or os.getenv("FIREWORKS_API_KEY") or ""
)
values["fireworks_api_key"] = fireworks_api_key
# note this sets it globally for module
# there isn't currently a way to pass it into client
api_key = fireworks_api_key.get_secret_value()
values["_client"] = OpenAI(
api_key=api_key, base_url="https://api.fireworks.ai/inference/v1"
)
return values
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed search docs."""
return [
i.embedding
for i in self._client.embeddings.create(input=texts, model=self.model).data
]
def embed_query(self, text: str) -> List[float]:
"""Embed query text."""
return self.embed_documents([text])[0]

View File

@@ -0,0 +1,222 @@
"""Wrapper around Fireworks AI's Completion API."""
import logging
from typing import Any, Dict, List, Optional
import requests
from aiohttp import ClientSession
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Extra, Field, SecretStr, root_validator
from langchain_core.utils import (
convert_to_secret_str,
get_from_dict_or_env,
get_pydantic_field_names,
)
from langchain_core.utils.utils import build_extra_kwargs
from langchain_fireworks.version import __version__
logger = logging.getLogger(__name__)
class Fireworks(LLM):
"""LLM models from `Fireworks`.
To use, you'll need an API key which you can find here:
https://fireworks.ai This can be passed in as init param
``fireworks_api_key`` or set as environment variable ``FIREWORKS_API_KEY``.
Fireworks AI API reference: https://readme.fireworks.ai/
Example:
.. code-block:: python
response = fireworks.generate(["Tell me a joke."])
"""
base_url: str = "https://api.fireworks.ai/inference/v1/completions"
"""Base inference API URL."""
fireworks_api_key: SecretStr = Field(default=None, alias="api_key")
"""Fireworks AI API key. Get it here: https://fireworks.ai"""
model: str
"""Model name. Available models listed here:
https://readme.fireworks.ai/
"""
temperature: Optional[float] = None
"""Model temperature."""
top_p: Optional[float] = None
"""Used to dynamically adjust the number of choices for each predicted token based
on the cumulative probabilities. A value of 1 will always yield the same
output. A temperature less than 1 favors more correctness and is appropriate
for question answering or summarization. A value greater than 1 introduces more
randomness in the output.
"""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
top_k: Optional[int] = None
"""Used to limit the number of choices for the next predicted word or token. It
specifies the maximum number of tokens to consider at each step, based on their
probability of occurrence. This technique helps to speed up the generation
process and can improve the quality of the generated text by focusing on the
most likely options.
"""
max_tokens: Optional[int] = None
"""The maximum number of tokens to generate."""
repetition_penalty: Optional[float] = None
"""A number that controls the diversity of generated text by reducing the
likelihood of repeated sequences. Higher values decrease repetition.
"""
logprobs: Optional[int] = None
"""An integer that specifies how many top token log probabilities are included in
the response for each token generation step.
"""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
allow_population_by_field_name = True
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})
values["model_kwargs"] = build_extra_kwargs(
extra, values, all_required_field_names
)
return values
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key exists in environment."""
values["fireworks_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "fireworks_api_key", "FIREWORKS_API_KEY")
)
return values
@property
def _llm_type(self) -> str:
"""Return type of model."""
return "fireworks"
def _format_output(self, output: dict) -> str:
return output["choices"][0]["text"]
@staticmethod
def get_user_agent() -> str:
return f"langchain-fireworks/{__version__}"
@property
def default_params(self) -> Dict[str, Any]:
return {
"model": self.model,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
"max_tokens": self.max_tokens,
"repetition_penalty": self.repetition_penalty,
}
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out to Fireworks's text generation endpoint.
Args:
prompt: The prompt to pass into the model.
Returns:
The string generated by the model..
"""
headers = {
"Authorization": f"Bearer {self.fireworks_api_key.get_secret_value()}",
"Content-Type": "application/json",
}
stop_to_use = stop[0] if stop and len(stop) == 1 else stop
payload: Dict[str, Any] = {
**self.default_params,
"prompt": prompt,
"stop": stop_to_use,
**kwargs,
}
# filter None values to not pass them to the http payload
payload = {k: v for k, v in payload.items() if v is not None}
response = requests.post(url=self.base_url, json=payload, headers=headers)
if response.status_code >= 500:
raise Exception(f"Fireworks Server: Error {response.status_code}")
elif response.status_code >= 400:
raise ValueError(f"Fireworks received an invalid payload: {response.text}")
elif response.status_code != 200:
raise Exception(
f"Fireworks returned an unexpected response with status "
f"{response.status_code}: {response.text}"
)
data = response.json()
output = self._format_output(data)
return output
async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call Fireworks model to get predictions based on the prompt.
Args:
prompt: The prompt to pass into the model.
Returns:
The string generated by the model.
"""
headers = {
"Authorization": f"Bearer {self.fireworks_api_key.get_secret_value()}",
"Content-Type": "application/json",
}
stop_to_use = stop[0] if stop and len(stop) == 1 else stop
payload: Dict[str, Any] = {
**self.default_params,
"prompt": prompt,
"stop": stop_to_use,
**kwargs,
}
# filter None values to not pass them to the http payload
payload = {k: v for k, v in payload.items() if v is not None}
async with ClientSession() as session:
async with session.post(
self.base_url, json=payload, headers=headers
) as response:
if response.status >= 500:
raise Exception(f"Fireworks Server: Error {response.status}")
elif response.status >= 400:
raise ValueError(
f"Fireworks received an invalid payload: {response.text}"
)
elif response.status != 200:
raise Exception(
f"Fireworks returned an unexpected response with status "
f"{response.status}: {response.text}"
)
response_json = await response.json()
if response_json.get("status") != "finished":
err_msg = response_json.get("error", "Undefined Error")
raise Exception(err_msg)
output = self._format_output(response_json)
return output

View File

@@ -0,0 +1,8 @@
"""Main entrypoint into package."""
from importlib import metadata
try:
__version__ = metadata.version(__package__)
except metadata.PackageNotFoundError:
# Case where package metadata is not available.
__version__ = ""

1541
libs/partners/fireworks/poetry.lock generated Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,97 @@
[tool.poetry]
name = "langchain-fireworks"
version = "0.0.1"
description = "An integration package connecting Fireworks and LangChain"
authors = []
readme = "README.md"
repository = "https://github.com/langchain-ai/langchain"
license = "MIT"
[tool.poetry.urls]
"Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/fireworks"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = "^0.1.16"
fireworks-ai = ">=0.12.0,<1"
openai = "^1.10.0"
requests = "^2"
aiohttp = "^3.9.1"
[tool.poetry.group.test]
optional = true
[tool.poetry.group.test.dependencies]
pytest = "^7.3.0"
freezegun = "^1.2.2"
pytest-mock = "^3.10.0"
syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
langchain-core = { path = "../../core", develop = true }
[tool.poetry.group.codespell]
optional = true
[tool.poetry.group.codespell.dependencies]
codespell = "^2.2.0"
[tool.poetry.group.test_integration]
optional = true
[tool.poetry.group.test_integration.dependencies]
[tool.poetry.group.lint]
optional = true
[tool.poetry.group.lint.dependencies]
ruff = "^0.2.2"
[tool.poetry.group.typing.dependencies]
mypy = "^0.991"
langchain-core = { path = "../../core", develop = true }
types-requests = "^2"
[tool.poetry.group.dev]
optional = true
[tool.poetry.group.dev.dependencies]
langchain-core = { path = "../../core", develop = true }
[tool.ruff.lint]
select = [
"E", # pycodestyle
"F", # pyflakes
"I", # isort
"T201", # print
]
[tool.mypy]
disallow_untyped_defs = "True"
[tool.coverage.run]
omit = ["tests/*"]
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
[tool.pytest.ini_options]
# --strict-markers will raise errors on unknown marks.
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
#
# https://docs.pytest.org/en/7.1.x/reference/reference.html
# --strict-config any warnings encountered while parsing the `pytest`
# section of the configuration file raise errors.
#
# https://github.com/tophat/syrupy
# --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite.
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
# Registering custom markers.
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
markers = [
"requires: mark tests as requiring a specific library",
"asyncio: mark tests as requiring asyncio",
"compile: mark placeholder test used to compile integration tests without running them",
]
asyncio_mode = "auto"

View File

@@ -0,0 +1,17 @@
import sys
import traceback
from importlib.machinery import SourceFileLoader
if __name__ == "__main__":
files = sys.argv[1:]
has_failure = False
for file in files:
try:
SourceFileLoader("x", file).load_module()
except Exception:
has_faillure = True
print(file) # noqa: T201
traceback.print_exc()
print() # noqa: T201
sys.exit(1 if has_failure else 0)

View File

@@ -0,0 +1,27 @@
#!/bin/bash
#
# This script searches for lines starting with "import pydantic" or "from pydantic"
# in tracked files within a Git repository.
#
# Usage: ./scripts/check_pydantic.sh /path/to/repository
# Check if a path argument is provided
if [ $# -ne 1 ]; then
echo "Usage: $0 /path/to/repository"
exit 1
fi
repository_path="$1"
# Search for lines matching the pattern within the specified repository
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
# Check if any matching lines were found
if [ -n "$result" ]; then
echo "ERROR: The following lines need to be updated:"
echo "$result"
echo "Please replace the code with an import from langchain_core.pydantic_v1."
echo "For example, replace 'from pydantic import BaseModel'"
echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
exit 1
fi

View File

@@ -0,0 +1,17 @@
#!/bin/bash
set -eu
# Initialize a variable to keep track of errors
errors=0
# make sure not importing from langchain or langchain_experimental
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))
# Decide on an exit status based on the errors
if [ "$errors" -gt 0 ]; then
exit 1
else
exit 0
fi

View File

@@ -0,0 +1,7 @@
import pytest
@pytest.mark.compile
def test_placeholder() -> None:
"""Used for compiling integration tests without running any real tests."""
pass

View File

@@ -0,0 +1,20 @@
"""Test Fireworks embeddings."""
from langchain_fireworks.embeddings import FireworksEmbeddings
def test_langchain_fireworks_embedding_documents() -> None:
"""Test Fireworks hosted embeddings."""
documents = ["foo bar"]
embedding = FireworksEmbeddings(model="nomic-ai/nomic-embed-text-v1.5")
output = embedding.embed_documents(documents)
assert len(output) == 1
assert len(output[0]) > 0
def test_langchain_fireworks_embedding_query() -> None:
"""Test Fireworks hosted embeddings."""
document = "foo bar"
embedding = FireworksEmbeddings(model="nomic-ai/nomic-embed-text-v1.5")
output = embedding.embed_query(document)
assert len(output) > 0

View File

@@ -0,0 +1,41 @@
"""Test Fireworks API wrapper.
In order to run this test, you need to have an Fireworks api key.
You can get it by registering for free at https://api.fireworks.ai/.
A test key can be found at https://api.fireworks.ai/settings/api-keys
You'll then need to set FIREWORKS_API_KEY environment variable to your api key.
"""
import pytest as pytest
from langchain_fireworks import Fireworks
def test_fireworks_call() -> None:
"""Test simple call to fireworks."""
llm = Fireworks(
model="accounts/fireworks/models/mixtral-8x7b-instruct",
temperature=0.2,
max_tokens=250,
)
output = llm.invoke("Say foo:")
assert llm._llm_type == "fireworks"
assert isinstance(output, str)
assert len(output) > 0
async def test_fireworks_acall() -> None:
"""Test simple call to fireworks."""
llm = Fireworks(
model="accounts/fireworks/models/mixtral-8x7b-instruct",
temperature=0.2,
max_tokens=250,
)
output = await llm.agenerate(["Say foo:"], stop=["bar"])
assert llm._llm_type == "fireworks"
output_text = output.generations[0][0].text
assert isinstance(output_text, str)
assert output_text.count("bar") <= 1

View File

@@ -0,0 +1,8 @@
"""Test embedding model integration."""
from langchain_fireworks.embeddings import FireworksEmbeddings
def test_initialization() -> None:
"""Test embedding model initialization."""
FireworksEmbeddings(model="nomic-ai/nomic-embed-text-v1.5")

View File

@@ -0,0 +1,12 @@
from langchain_fireworks import __all__
EXPECTED_ALL = [
"__version__",
"ChatFireworks",
"Fireworks",
"FireworksEmbeddings",
]
def test_all_imports() -> None:
assert sorted(EXPECTED_ALL) == sorted(__all__)

View File

@@ -0,0 +1,62 @@
"""Test Fireworks LLM"""
from typing import cast
from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture, MonkeyPatch
from langchain_fireworks import Fireworks
def test_fireworks_api_key_is_secret_string() -> None:
"""Test that the API key is stored as a SecretStr."""
llm = Fireworks(
fireworks_api_key="secret-api-key",
model="accounts/fireworks/models/mixtral-8x7b-instruct",
temperature=0.2,
max_tokens=250,
)
assert isinstance(llm.fireworks_api_key, SecretStr)
def test_fireworks_api_key_masked_when_passed_from_env(
monkeypatch: MonkeyPatch, capsys: CaptureFixture
) -> None:
"""Test that the API key is masked when passed from an environment variable."""
monkeypatch.setenv("FIREWORKS_API_KEY", "secret-api-key")
llm = Fireworks(
model="accounts/fireworks/models/mixtral-8x7b-instruct",
temperature=0.2,
max_tokens=250,
)
print(llm.fireworks_api_key, end="") # noqa: T201
captured = capsys.readouterr()
assert captured.out == "**********"
def test_fireworks_api_key_masked_when_passed_via_constructor(
capsys: CaptureFixture,
) -> None:
"""Test that the API key is masked when passed via the constructor."""
llm = Fireworks(
fireworks_api_key="secret-api-key",
model="accounts/fireworks/models/mixtral-8x7b-instruct",
temperature=0.2,
max_tokens=250,
)
print(llm.fireworks_api_key, end="") # noqa: T201
captured = capsys.readouterr()
assert captured.out == "**********"
def test_fireworks_uses_actual_secret_value_from_secretstr() -> None:
"""Test that the actual secret value is correctly retrieved."""
llm = Fireworks(
fireworks_api_key="secret-api-key",
model="accounts/fireworks/models/mixtral-8x7b-instruct",
temperature=0.2,
max_tokens=250,
)
assert cast(SecretStr, llm.fireworks_api_key).get_secret_value() == "secret-api-key"