Fireworks integration (#8322)

Description - Integrates Fireworks into LangChain's LLMs so users can call Fireworks
models from LangChain, primarily for summarization (a usage sketch follows below).

Issue - Not applicable
Dependencies - None
Tag maintainer - @rlancemartin
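
A minimal summarization sketch of the intended usage, assuming `FIREWORKS_API_KEY` is exported; the model id and documents are illustrative, not part of the diff:

# Hedged sketch: summarize documents with the new Fireworks wrapper.
# Assumes FIREWORKS_API_KEY is exported; the model id is illustrative.
from langchain.chains.summarize import load_summarize_chain
from langchain.docstore.document import Document
from langchain.llms.fireworks import Fireworks

llm = Fireworks(model_id="fireworks-llama-v2-13b-chat", max_tokens=256)
docs = [
    Document(page_content="Fireworks hosts open-source LLMs behind a simple HTTP API."),
    Document(page_content="This PR wraps that API so LangChain chains can call it."),
]
chain = load_summarize_chain(llm, chain_type="stuff")
print(chain.run(docs))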

---------

Co-authored-by: Raj Janardhan <rajjanardhan@Rajs-Laptop.attlocal.net>
Authored by rjanardhan3 on 2023-08-01 21:17:26 -07:00; committed by GitHub.
parent b574507c51
commit 68113348cc
5 changed files with 790 additions and 0 deletions


@@ -39,6 +39,7 @@ from langchain.llms.ctransformers import CTransformers
from langchain.llms.databricks import Databricks
from langchain.llms.deepinfra import DeepInfra
from langchain.llms.fake import FakeListLLM
from langchain.llms.fireworks import Fireworks, FireworksChat
from langchain.llms.forefrontai import ForefrontAI
from langchain.llms.google_palm import GooglePalm
from langchain.llms.gooseai import GooseAI
@@ -98,6 +99,8 @@ __all__ = [
"Databricks",
"DeepInfra",
"FakeListLLM",
"Fireworks",
"FireworksChat",
"ForefrontAI",
"GPT4All",
"GooglePalm",


@@ -0,0 +1,377 @@
"""Wrapper around Fireworks APIs"""
import json
import logging
from typing import (
Any,
Dict,
List,
Optional,
Set,
Tuple,
Union,
)
import requests
from pydantic import Field, root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.llms.base import BaseLLM
from langchain.schema import Generation, LLMResult
from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)


class BaseFireworks(BaseLLM):
    """Wrapper around Fireworks large language models."""

    model_id: str = Field("fireworks-llama-v2-7b-chat", alias="model")
    """Model name to use."""
    temperature: float = 0.7
    """What sampling temperature to use."""
    max_tokens: int = 512
    """The maximum number of tokens to generate in the completion.
    -1 returns as many tokens as possible given the prompt and
    the model's maximal context size."""
    top_p: float = 1
    """Total probability mass of tokens to consider at each step."""
    fireworks_api_key: Optional[str] = None
    """API key for the Fireworks API."""
    batch_size: int = 20
    """Batch size to use when passing multiple documents to generate."""
    request_timeout: Optional[Union[float, Tuple[float, float]]] = None
    """Timeout for requests to the Fireworks completion API. Default is 600 seconds."""
    max_retries: int = 6
    """Maximum number of retries to make when generating."""

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"fireworks_api_key": "FIREWORKS_API_KEY"}

    @property
    def lc_serializable(self) -> bool:
        return True

    def __new__(cls, **data: Any) -> Any:
        """Initialize the Fireworks object."""
        data.get("model_id", "")
        return super().__new__(cls)

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        values["fireworks_api_key"] = get_from_dict_or_env(
            values, "fireworks_api_key", "FIREWORKS_API_KEY"
        )
        return values

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Call out to Fireworks endpoint with k unique prompts.

        Args:
            prompts: The prompts to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The full LLM output.
        """
        params = {"model": self.model_id}
        params = {**params, **kwargs}
        sub_prompts = self.get_batch_prompts(params, prompts, stop)
        choices = []
        token_usage: Dict[str, int] = {}
        _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
        for _prompts in sub_prompts:
            # Send each batch of prompts, not the full prompt list.
            response = completion_with_retry(self, prompt=_prompts, **params)
            choices.extend(response)
            update_token_usage(_keys, response, token_usage)
        return self.create_llm_result(choices, prompts, token_usage)

    async def _agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Call out to Fireworks endpoint async with k unique prompts."""
        params = {"model": self.model_id}
        params = {**params, **kwargs}
        sub_prompts = self.get_batch_prompts(params, prompts, stop)
        choices = []
        token_usage: Dict[str, int] = {}
        _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
        for _prompts in sub_prompts:
            response = await acompletion_with_retry(self, prompt=_prompts, **params)
            choices.extend(response)
            update_token_usage(_keys, response, token_usage)
        return self.create_llm_result(choices, prompts, token_usage)

    def get_batch_prompts(
        self,
        params: Dict[str, Any],
        prompts: List[str],
        stop: Optional[List[str]] = None,
    ) -> List[List[str]]:
        """Get the sub prompts for llm call."""
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")
        sub_prompts = [
            prompts[i : i + self.batch_size]
            for i in range(0, len(prompts), self.batch_size)
        ]
        return sub_prompts

    def create_llm_result(
        self, choices: Any, prompts: List[str], token_usage: Dict[str, int]
    ) -> LLMResult:
        """Create the LLMResult from the choices and prompts."""
        generations = []
        for i, _ in enumerate(prompts):
            sub_choices = choices[i : (i + 1)]
            generations.append(
                [
                    Generation(
                        text=choice,
                    )
                    for choice in sub_choices
                ]
            )
        llm_output = {"token_usage": token_usage, "model_id": self.model_id}
        return LLMResult(generations=generations, llm_output=llm_output)

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "fireworks"


class FireworksChat(BaseLLM):
    """Wrapper around Fireworks Chat large language models.

    To use, you should have the environment variable ``FIREWORKS_API_KEY``
    set with your API key.

    Any parameters that are valid to be passed to the fireworks.create
    call can be passed in, even if not explicitly saved on this class.

    Example:
        .. code-block:: python

            from langchain.llms import FireworksChat
            fireworkschat = FireworksChat(model_id="fireworks-llama-v2-13b-chat")
    """

    model_id: str = "fireworks-llama-v2-7b-chat"
    """Model name to use."""
    temperature: float = 0.7
    """What sampling temperature to use."""
    max_tokens: int = 512
    """The maximum number of tokens to generate in the completion.
    -1 returns as many tokens as possible given the prompt and
    the model's maximal context size."""
    top_p: float = 1
    """Total probability mass of tokens to consider at each step."""
    fireworks_api_key: Optional[str] = None
    """API key for the Fireworks API."""
    max_retries: int = 6
    """Maximum number of retries to make when generating."""
    request_timeout: Optional[Union[float, Tuple[float, float]]] = None
    """Timeout for requests to the Fireworks completion API. Default is 600 seconds."""
    prefix_messages: List = Field(default_factory=list)
    """Series of messages for Chat input."""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment"""
        values["fireworks_api_key"] = get_from_dict_or_env(
            values, "fireworks_api_key", "FIREWORKS_API_KEY"
        )
        return values

    def _get_chat_params(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> Tuple:
        if len(prompts) > 1:
            raise ValueError(
                f"FireworksChat currently only supports single prompt, got {prompts}"
            )
        messages = self.prefix_messages + [{"role": "user", "content": prompts[0]}]
        params: Dict[str, Any] = {**{"model": self.model_id}}
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")
        return messages, params

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        messages, params = self._get_chat_params(prompts, stop)
        params = {**params, **kwargs}
        full_response = completion_with_retry(self, messages=messages, **params)
        llm_output = {
            "model_id": self.model_id,
        }
        return LLMResult(
            generations=[[Generation(text=full_response[0])]],
            llm_output=llm_output,
        )

    async def _agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        messages, params = self._get_chat_params(prompts, stop)
        params = {**params, **kwargs}
        full_response = await acompletion_with_retry(self, messages=messages, **params)
        llm_output = {
            "model_id": self.model_id,
        }
        return LLMResult(
            generations=[[Generation(text=full_response[0])]],
            llm_output=llm_output,
        )

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "fireworks-chat"


class Fireworks(BaseFireworks):
    """Wrapper around Fireworks large language models.

    To use, you should have the environment variable ``FIREWORKS_API_KEY``
    set with your API key.

    Any parameters that are valid to be passed to the fireworks.create
    call can be passed in, even if not explicitly saved on this class.

    Example:
        .. code-block:: python

            from langchain.llms import Fireworks
            llm = Fireworks(model_id="fireworks-llama-v2-13b")
    """


def update_token_usage(
    keys: Set[str], response: Dict[str, Any], token_usage: Dict[str, Any]
) -> None:
    """Update token usage."""
    _keys_to_use = keys.intersection(response)
    for _key in _keys_to_use:
        if _key not in token_usage:
            token_usage[_key] = response["usage"][_key]
        else:
            token_usage[_key] += response["usage"][_key]


def execute(
    prompt: str,
    model: str,
    api_key: Optional[str],
    max_tokens: int = 256,
    temperature: float = 0.0,
    top_p: float = 1.0,
) -> Any:
    """Execute LLM query"""
    requestUrl = "https://api.fireworks.ai/inference/v1/completions"
    requestBody = {
        "model": model,
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
    }
    requestHeaders = {
        "Authorization": f"Bearer {api_key}",
        "Accept": "application/json",
        "Content-Type": "application/json",
    }
    response = requests.post(requestUrl, headers=requestHeaders, json=requestBody)
    return response.text


def completion_with_retry(
    llm: Union[BaseFireworks, FireworksChat], **kwargs: Any
) -> Any:
    """Use tenacity to retry the completion call."""
    if "prompt" not in kwargs.keys():
        answers = []
        for i in range(len(kwargs["messages"])):
            result = kwargs["messages"][i]["content"]
            result = execute(
                result,
                kwargs["model"],
                llm.fireworks_api_key,
                llm.max_tokens,
                llm.temperature,
                llm.top_p,
            )
            curr_string = json.loads(result)["choices"][0]["text"]
            answers.append(curr_string)
    else:
        answers = []
        for i in range(len(kwargs["prompt"])):
            result = kwargs["prompt"][i]
            result = execute(
                result,
                kwargs["model"],
                llm.fireworks_api_key,
                llm.max_tokens,
                llm.temperature,
                llm.top_p,
            )
            curr_string = json.loads(result)["choices"][0]["text"]
            answers.append(curr_string)
    return answers


async def acompletion_with_retry(
    llm: Union[BaseFireworks, FireworksChat], **kwargs: Any
) -> Any:
    """Use tenacity to retry the async completion call."""
    if "prompt" not in kwargs.keys():
        answers = []
        for i in range(len(kwargs["messages"])):
            result = kwargs["messages"][i]["content"]
            result = execute(
                result,
                kwargs["model"],
                llm.fireworks_api_key,
                llm.max_tokens,
                llm.temperature,
                llm.top_p,
            )
            curr_string = json.loads(result)["choices"][0]["text"]
            answers.append(curr_string)
    else:
        answers = []
        for i in range(len(kwargs["prompt"])):
            result = kwargs["prompt"][i]
            result = execute(
                result,
                kwargs["model"],
                llm.fireworks_api_key,
                llm.max_tokens,
                llm.temperature,
                llm.top_p,
            )
            curr_string = json.loads(result)["choices"][0]["text"]
            answers.append(curr_string)
    return answers
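
Both wrappers funnel every call through execute(), which POSTs to the Fireworks completions endpoint and returns the raw response body. A hedged sketch of calling it directly (the key is read from the environment; the model id and prompt are illustrative):

import json
import os

from langchain.llms.fireworks import execute

# Hedged sketch: one raw completion request via the module-level helper.
raw = execute(
    prompt="Summarize: Fireworks serves open-source LLMs over an HTTP API.",
    model="fireworks-llama-v2-7b",  # illustrative model id
    api_key=os.environ["FIREWORKS_API_KEY"],
    max_tokens=64,
    temperature=0.2,
)
print(json.loads(raw)["choices"][0]["text"])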


@@ -0,0 +1,157 @@
"""Test Fireworks AI API Wrapper."""
from pathlib import Path
import pytest
from langchain import LLMChain, PromptTemplate
from langchain.chains import RetrievalQA
from langchain.document_loaders import TextLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import OpenAIChat
from langchain.llms.fireworks import Fireworks, FireworksChat
from langchain.llms.loading import load_llm
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
)
from langchain.schema import LLMResult
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import DeepLake


def test_fireworks_call() -> None:
    """Test valid call to fireworks."""
    llm = Fireworks(model_id="fireworks-llama-v2-13b-chat", max_tokens=900)
    output = llm("What is the weather in NYC")
    assert isinstance(output, str)


def test_fireworks_in_chain() -> None:
    """Tests fireworks AI in a Langchain chain"""
    human_message_prompt = HumanMessagePromptTemplate(
        prompt=PromptTemplate(
            template="What is a good name for a company that makes {product}?",
            input_variables=["product"],
        )
    )
    chat_prompt_template = ChatPromptTemplate.from_messages([human_message_prompt])
    chat = Fireworks()
    chain = LLMChain(llm=chat, prompt=chat_prompt_template)
    output = chain.run("football helmets")
    assert isinstance(output, str)


@pytest.mark.asyncio
async def test_openai_chat_async_generate() -> None:
    """Test async chat."""
    llm = OpenAIChat(max_tokens=10)
    output = await llm.agenerate(["Hello, how are you?"])
    assert isinstance(output, LLMResult)


def test_fireworks_model_param() -> None:
    """Tests model parameters for Fireworks"""
    llm = Fireworks(model="foo")
    assert llm.model_id == "foo"
    llm = Fireworks(model_id="foo")
    assert llm.model_id == "foo"


def test_fireworkschat_model_param() -> None:
    """Tests model parameters for FireworksChat"""
    llm = FireworksChat(model="foo")
    assert llm.model_id == "foo"
    llm = FireworksChat(model_id="foo")
    assert llm.model_id == "foo"


def test_saving_loading_llm(tmp_path: Path) -> None:
    """Test saving/loading a Fireworks LLM."""
    llm = Fireworks(max_tokens=10)
    llm.save(file_path=tmp_path / "fireworks.yaml")
    loaded_llm = load_llm(tmp_path / "fireworks.yaml")
    assert loaded_llm == llm


def test_fireworks_multiple_prompts() -> None:
    """Test completion with multiple prompts."""
    llm = Fireworks()
    output = llm.generate(["How is the weather in New York today?", "I'm pickle rick"])
    assert isinstance(output, LLMResult)
    assert isinstance(output.generations, list)
    assert len(output.generations) == 2


def test_fireworks_chat() -> None:
    """Test FireworksChat."""
    llm = FireworksChat()
    output = llm("Name me 3 quick facts about the New England Patriots")
    assert isinstance(output, str)


async def test_fireworks_agenerate() -> None:
    llm = Fireworks()
    output = await llm.agenerate(["I'm a pickle", "I'm a pickle"])
    assert isinstance(output, LLMResult)
    assert isinstance(output.generations, list)
    assert len(output.generations) == 2


async def test_fireworkschat_agenerate() -> None:
    llm = FireworksChat(max_tokens=10)
    output = await llm.agenerate(["Hello, how are you?"])
    assert isinstance(output, LLMResult)
    assert isinstance(output.generations, list)
    assert len(output.generations) == 1


def test_fireworkschat_chain() -> None:
    embeddings = OpenAIEmbeddings()
    loader = TextLoader(
        "[workspace]/langchain-internal/docs/extras/modules/state_of_the_union.txt"
    )
    documents = loader.load()
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    docs = text_splitter.split_documents(documents)
    embeddings = OpenAIEmbeddings()
    db = DeepLake(
        dataset_path="./my_deeplake/", embedding_function=embeddings, overwrite=True
    )
    db.add_documents(docs)
    query = "What did the president say about Ketanji Brown Jackson"
    docs = db.similarity_search(query)
    qa = RetrievalQA.from_chain_type(
        llm=FireworksChat(),
        chain_type="stuff",
        retriever=db.as_retriever(),
    )
    query = "What did the president say about Ketanji Brown Jackson"
    output = qa.run(query)
    assert isinstance(output, str)


_EXPECTED_NUM_TOKENS = {
    "fireworks-llama-v2-13b": 17,
    "fireworks-llama-v2-7b": 17,
    "fireworks-llama-v2-13b-chat": 17,
    "fireworks-llama-v2-7b-chat": 17,
}

_MODELS = [
    "fireworks-llama-v2-13b",
    "fireworks-llama-v2-7b",
    "fireworks-llama-v2-13b-chat",
    "fireworks-llama-v2-7b-chat",
]


@pytest.mark.parametrize("model", _MODELS)
def test_fireworks_get_num_tokens(model: str) -> None:
    """Test get_tokens."""
    llm = Fireworks(model=model)
    assert llm.get_num_tokens("表情符号是\n🦜🔗") == _EXPECTED_NUM_TOKENS[model]
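
These tests hit the live Fireworks (and, in places, OpenAI) APIs, so they belong with the integration tests; a hypothetical way to run just this file once the keys are exported (the file path is assumed, not shown in the diff):

# Hypothetical invocation; the test-file path is an assumption.
import pytest

pytest.main(["-q", "tests/integration_tests/llms/test_fireworks.py"])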