Mirror of https://github.com/hwchase17/langchain.git (synced 2025-09-16 06:53:16 +00:00)
Fireworks integration (#8322)
Description - Integrates Fireworks into LangChain's LLMs so that users can call Fireworks models from LangChain, mainly for summarization.
Issue - Not applicable
Dependencies - None
Tag maintainer - @rlancemartin
Co-authored-by: Raj Janardhan <rajjanardhan@Rajs-Laptop.attlocal.net>
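A minimal usage sketch (not part of this commit) of the new wrapper for the summarization use case described above; it assumes ``FIREWORKS_API_KEY`` is set in the environment and uses a model id taken from the integration tests:

    from langchain import LLMChain, PromptTemplate
    from langchain.llms import Fireworks

    llm = Fireworks(model_id="fireworks-llama-v2-13b-chat", max_tokens=512)
    prompt = PromptTemplate(
        template="Summarize the following text in two sentences:\n\n{text}",
        input_variables=["text"],
    )
    chain = LLMChain(llm=llm, prompt=prompt)
    summary = chain.run("<text to summarize>")  # returns a plain string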
libs/langchain/langchain/llms/__init__.py
@@ -39,6 +39,7 @@ from langchain.llms.ctransformers import CTransformers
 from langchain.llms.databricks import Databricks
 from langchain.llms.deepinfra import DeepInfra
 from langchain.llms.fake import FakeListLLM
+from langchain.llms.fireworks import Fireworks, FireworksChat
 from langchain.llms.forefrontai import ForefrontAI
 from langchain.llms.google_palm import GooglePalm
 from langchain.llms.gooseai import GooseAI
@@ -98,6 +99,8 @@ __all__ = [
     "Databricks",
     "DeepInfra",
     "FakeListLLM",
+    "Fireworks",
+    "FireworksChat",
     "ForefrontAI",
     "GPT4All",
     "GooglePalm",
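With these exports in place, both wrappers can be imported from the top-level llms package. A short sketch (model ids taken from the docstrings and tests below; availability on Fireworks is assumed):

    from langchain.llms import Fireworks, FireworksChat

    llm = Fireworks(model="fireworks-llama-v2-13b")  # "model" is an alias for model_id
    chat = FireworksChat(model_id="fireworks-llama-v2-13b-chat")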

libs/langchain/langchain/llms/fireworks.py (new file, 377 lines)
@@ -0,0 +1,377 @@
"""Wrapper around Fireworks APIs."""
import json
import logging
from typing import (
    Any,
    Dict,
    List,
    Optional,
    Set,
    Tuple,
    Union,
)

import requests
from pydantic import Field, root_validator

from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.llms.base import BaseLLM
from langchain.schema import Generation, LLMResult
from langchain.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)


class BaseFireworks(BaseLLM):
    """Wrapper around Fireworks large language models."""

    model_id: str = Field("fireworks-llama-v2-7b-chat", alias="model")
    """Model name to use."""
    temperature: float = 0.7
    """What sampling temperature to use."""
    max_tokens: int = 512
    """The maximum number of tokens to generate in the completion.
    -1 returns as many tokens as possible given the prompt and
    the model's maximal context size."""
    top_p: float = 1
    """Total probability mass of tokens to consider at each step."""
    fireworks_api_key: Optional[str] = None
    """API key for the Fireworks API."""
    batch_size: int = 20
    """Batch size to use when passing multiple documents to generate."""
    request_timeout: Optional[Union[float, Tuple[float, float]]] = None
    """Timeout for requests to the Fireworks completion API. Default is 600 seconds."""
    max_retries: int = 6
    """Maximum number of retries to make when generating."""

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"fireworks_api_key": "FIREWORKS_API_KEY"}

    @property
    def lc_serializable(self) -> bool:
        return True

    def __new__(cls, **data: Any) -> Any:
        """Initialize the Fireworks object."""
        data.get("model_id", "")
        return super().__new__(cls)

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key exists in the environment."""
        values["fireworks_api_key"] = get_from_dict_or_env(
            values, "fireworks_api_key", "FIREWORKS_API_KEY"
        )
        return values

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Call out to the Fireworks endpoint with k unique prompts.

        Args:
            prompts: The prompts to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The full LLM output.
        """
        params = {"model": self.model_id}
        params = {**params, **kwargs}
        sub_prompts = self.get_batch_prompts(params, prompts, stop)
        choices = []
        token_usage: Dict[str, int] = {}
        _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
        for _prompts in sub_prompts:
            # Send one batch of prompts at a time, mirroring the async path below.
            response = completion_with_retry(self, prompt=_prompts, **params)
            choices.extend(response)
            update_token_usage(_keys, response, token_usage)

        return self.create_llm_result(choices, prompts, token_usage)

    async def _agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Call out to the Fireworks endpoint async with k unique prompts."""
        params = {"model": self.model_id}
        params = {**params, **kwargs}
        sub_prompts = self.get_batch_prompts(params, prompts, stop)
        choices = []
        token_usage: Dict[str, int] = {}
        _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
        for _prompts in sub_prompts:
            response = await acompletion_with_retry(self, prompt=_prompts, **params)
            choices.extend(response)
            update_token_usage(_keys, response, token_usage)

        return self.create_llm_result(choices, prompts, token_usage)

    def get_batch_prompts(
        self,
        params: Dict[str, Any],
        prompts: List[str],
        stop: Optional[List[str]] = None,
    ) -> List[List[str]]:
        """Get the sub prompts for llm call."""
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")

        sub_prompts = [
            prompts[i : i + self.batch_size]
            for i in range(0, len(prompts), self.batch_size)
        ]
        return sub_prompts

    def create_llm_result(
        self, choices: Any, prompts: List[str], token_usage: Dict[str, int]
    ) -> LLMResult:
        """Create the LLMResult from the choices and prompts."""
        generations = []

        for i, _ in enumerate(prompts):
            sub_choices = choices[i : (i + 1)]
            generations.append(
                [
                    Generation(
                        text=choice,
                    )
                    for choice in sub_choices
                ]
            )
        llm_output = {"token_usage": token_usage, "model_id": self.model_id}
        return LLMResult(generations=generations, llm_output=llm_output)

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "fireworks"


class FireworksChat(BaseLLM):
    """Wrapper around Fireworks Chat large language models.

    To use, you should have the environment variable ``FIREWORKS_API_KEY``
    set with your API key.

    Any parameters that are valid to be passed to the fireworks.create
    call can be passed in, even if not explicitly saved on this class.

    Example:
        .. code-block:: python

            from langchain.llms import FireworksChat
            fireworkschat = FireworksChat(model_id="fireworks-llama-v2-13b-chat")
    """

    model_id: str = "fireworks-llama-v2-7b-chat"
    """Model name to use."""
    temperature: float = 0.7
    """What sampling temperature to use."""
    max_tokens: int = 512
    """The maximum number of tokens to generate in the completion.
    -1 returns as many tokens as possible given the prompt and
    the model's maximal context size."""
    top_p: float = 1
    """Total probability mass of tokens to consider at each step."""
    fireworks_api_key: Optional[str] = None
    """API key for the Fireworks API."""
    max_retries: int = 6
    """Maximum number of retries to make when generating."""
    request_timeout: Optional[Union[float, Tuple[float, float]]] = None
    """Timeout for requests to the Fireworks completion API. Default is 600 seconds."""
    prefix_messages: List = Field(default_factory=list)
    """Series of messages for Chat input."""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key exists in the environment."""
        values["fireworks_api_key"] = get_from_dict_or_env(
            values, "fireworks_api_key", "FIREWORKS_API_KEY"
        )
        return values

    def _get_chat_params(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> Tuple:
        if len(prompts) > 1:
            raise ValueError(
                f"FireworksChat currently only supports single prompt, got {prompts}"
            )
        messages = self.prefix_messages + [{"role": "user", "content": prompts[0]}]
        params: Dict[str, Any] = {**{"model": self.model_id}}
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")

        return messages, params

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        messages, params = self._get_chat_params(prompts, stop)
        params = {**params, **kwargs}
        full_response = completion_with_retry(self, messages=messages, **params)
        llm_output = {
            "model_id": self.model_id,
        }
        return LLMResult(
            generations=[[Generation(text=full_response[0])]],
            llm_output=llm_output,
        )

    async def _agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        messages, params = self._get_chat_params(prompts, stop)
        params = {**params, **kwargs}
        full_response = await acompletion_with_retry(self, messages=messages, **params)
        llm_output = {
            "model_id": self.model_id,
        }
        return LLMResult(
            generations=[[Generation(text=full_response[0])]],
            llm_output=llm_output,
        )

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "fireworks-chat"


class Fireworks(BaseFireworks):
    """Wrapper around Fireworks large language models.

    To use, you should have the environment variable ``FIREWORKS_API_KEY``
    set with your API key.

    Any parameters that are valid to be passed to the fireworks.create
    call can be passed in, even if not explicitly saved on this class.

    Example:
        .. code-block:: python

            from langchain.llms import Fireworks
            llm = Fireworks(model_id="fireworks-llama-v2-13b")
    """


def update_token_usage(
    keys: Set[str], response: Dict[str, Any], token_usage: Dict[str, Any]
) -> None:
    """Update token usage."""
    _keys_to_use = keys.intersection(response)
    for _key in _keys_to_use:
        if _key not in token_usage:
            token_usage[_key] = response["usage"][_key]
        else:
            token_usage[_key] += response["usage"][_key]


def execute(
    prompt: str,
    model: str,
    api_key: Optional[str],
    max_tokens: int = 256,
    temperature: float = 0.0,
    top_p: float = 1.0,
) -> Any:
    """Execute LLM query."""
    requestUrl = "https://api.fireworks.ai/inference/v1/completions"
    requestBody = {
        "model": model,
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
    }
    requestHeaders = {
        "Authorization": f"Bearer {api_key}",
        "Accept": "application/json",
        "Content-Type": "application/json",
    }
    response = requests.post(requestUrl, headers=requestHeaders, json=requestBody)
    return response.text


def completion_with_retry(
    llm: Union[BaseFireworks, FireworksChat], **kwargs: Any
) -> Any:
    """Call the Fireworks completion endpoint for each message or prompt."""
    if "prompt" not in kwargs.keys():
        # Chat path: send each message's content as a separate completion request.
        answers = []
        for i in range(len(kwargs["messages"])):
            result = kwargs["messages"][i]["content"]
            result = execute(
                result,
                kwargs["model"],
                llm.fireworks_api_key,
                llm.max_tokens,
                llm.temperature,
                llm.top_p,
            )
            curr_string = json.loads(result)["choices"][0]["text"]
            answers.append(curr_string)
    else:
        # Completion path: send each prompt in the batch as its own request.
        answers = []
        for i in range(len(kwargs["prompt"])):
            result = kwargs["prompt"][i]
            result = execute(
                result,
                kwargs["model"],
                llm.fireworks_api_key,
                llm.max_tokens,
                llm.temperature,
                llm.top_p,
            )
            curr_string = json.loads(result)["choices"][0]["text"]
            answers.append(curr_string)
    return answers


async def acompletion_with_retry(
    llm: Union[BaseFireworks, FireworksChat], **kwargs: Any
) -> Any:
    """Async variant: call the Fireworks completion endpoint for each message or prompt."""
    if "prompt" not in kwargs.keys():
        answers = []
        for i in range(len(kwargs["messages"])):
            result = kwargs["messages"][i]["content"]
            result = execute(
                result,
                kwargs["model"],
                llm.fireworks_api_key,
                llm.max_tokens,
                llm.temperature,
                llm.top_p,
            )
            curr_string = json.loads(result)["choices"][0]["text"]
            answers.append(curr_string)
    else:
        answers = []
        for i in range(len(kwargs["prompt"])):
            result = kwargs["prompt"][i]
            result = execute(
                result,
                kwargs["model"],
                llm.fireworks_api_key,
                llm.max_tokens,
                llm.temperature,
                llm.top_p,
            )
            curr_string = json.loads(result)["choices"][0]["text"]
            answers.append(curr_string)
    return answers
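For reference, a minimal standalone sketch (not part of the commit) of the raw HTTP call that ``execute`` makes against the Fireworks completions endpoint; the model id and prompt are placeholders and ``FIREWORKS_API_KEY`` is assumed to be set:

    import json
    import os

    import requests

    response = requests.post(
        "https://api.fireworks.ai/inference/v1/completions",
        headers={
            "Authorization": f"Bearer {os.environ['FIREWORKS_API_KEY']}",
            "Accept": "application/json",
            "Content-Type": "application/json",
        },
        json={
            "model": "fireworks-llama-v2-7b-chat",
            "prompt": "Say hello in one word.",
            "max_tokens": 16,
            "temperature": 0.7,
            "top_p": 1.0,
        },
    )
    text = json.loads(response.text)["choices"][0]["text"]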

libs/langchain/tests/integration_tests/llms/test_fireworks.py (new file, 157 lines)
@@ -0,0 +1,157 @@
"""Test Fireworks AI API Wrapper."""
from pathlib import Path

import pytest

from langchain import LLMChain, PromptTemplate
from langchain.chains import RetrievalQA
from langchain.document_loaders import TextLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import OpenAIChat
from langchain.llms.fireworks import Fireworks, FireworksChat
from langchain.llms.loading import load_llm
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.schema import LLMResult
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import DeepLake


def test_fireworks_call() -> None:
    """Test valid call to fireworks."""
    llm = Fireworks(model_id="fireworks-llama-v2-13b-chat", max_tokens=900)
    output = llm("What is the weather in NYC")
    assert isinstance(output, str)


def test_fireworks_in_chain() -> None:
    """Test Fireworks AI in a LangChain chain."""
    human_message_prompt = HumanMessagePromptTemplate(
        prompt=PromptTemplate(
            template="What is a good name for a company that makes {product}?",
            input_variables=["product"],
        )
    )
    chat_prompt_template = ChatPromptTemplate.from_messages([human_message_prompt])
    chat = Fireworks()
    chain = LLMChain(llm=chat, prompt=chat_prompt_template)
    output = chain.run("football helmets")
    assert isinstance(output, str)


@pytest.mark.asyncio
async def test_openai_chat_async_generate() -> None:
    """Test async chat."""
    llm = OpenAIChat(max_tokens=10)
    output = await llm.agenerate(["Hello, how are you?"])
    assert isinstance(output, LLMResult)


def test_fireworks_model_param() -> None:
    """Tests model parameters for Fireworks"""
    llm = Fireworks(model="foo")
    assert llm.model_id == "foo"
    llm = Fireworks(model_id="foo")
    assert llm.model_id == "foo"


def test_fireworkschat_model_param() -> None:
    """Tests model parameters for FireworksChat"""
    llm = FireworksChat(model="foo")
    assert llm.model_id == "foo"
    llm = FireworksChat(model_id="foo")
    assert llm.model_id == "foo"


def test_saving_loading_llm(tmp_path: Path) -> None:
    """Test saving/loading a Fireworks LLM."""
    llm = Fireworks(max_tokens=10)
    llm.save(file_path=tmp_path / "fireworks.yaml")
    loaded_llm = load_llm(tmp_path / "fireworks.yaml")
    assert loaded_llm == llm


def test_fireworks_multiple_prompts() -> None:
    """Test completion with multiple prompts."""
    llm = Fireworks()
    output = llm.generate(["How is the weather in New York today?", "I'm pickle rick"])
    assert isinstance(output, LLMResult)
    assert isinstance(output.generations, list)
    assert len(output.generations) == 2


def test_fireworks_chat() -> None:
    """Test FireworksChat."""
    llm = FireworksChat()
    output = llm("Name me 3 quick facts about the New England Patriots")
    assert isinstance(output, str)


async def test_fireworks_agenerate() -> None:
    llm = Fireworks()
    output = await llm.agenerate(["I'm a pickle", "I'm a pickle"])
    assert isinstance(output, LLMResult)
    assert isinstance(output.generations, list)
    assert len(output.generations) == 2


async def test_fireworkschat_agenerate() -> None:
    llm = FireworksChat(max_tokens=10)
    output = await llm.agenerate(["Hello, how are you?"])
    assert isinstance(output, LLMResult)
    assert isinstance(output.generations, list)
    assert len(output.generations) == 1


def test_fireworkschat_chain() -> None:
    loader = TextLoader(
        "[workspace]/langchain-internal/docs/extras/modules/state_of_the_union.txt"
    )
    documents = loader.load()
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    docs = text_splitter.split_documents(documents)

    embeddings = OpenAIEmbeddings()

    db = DeepLake(
        dataset_path="./my_deeplake/", embedding_function=embeddings, overwrite=True
    )
    db.add_documents(docs)

    query = "What did the president say about Ketanji Brown Jackson"
    docs = db.similarity_search(query)

    qa = RetrievalQA.from_chain_type(
        llm=FireworksChat(),
        chain_type="stuff",
        retriever=db.as_retriever(),
    )
    query = "What did the president say about Ketanji Brown Jackson"
    output = qa.run(query)
    assert isinstance(output, str)


_EXPECTED_NUM_TOKENS = {
    "fireworks-llama-v2-13b": 17,
    "fireworks-llama-v2-7b": 17,
    "fireworks-llama-v2-13b-chat": 17,
    "fireworks-llama-v2-7b-chat": 17,
}

_MODELS = [
    "fireworks-llama-v2-13b",
    "fireworks-llama-v2-7b",
    "fireworks-llama-v2-13b-chat",
    "fireworks-llama-v2-7b-chat",
]


@pytest.mark.parametrize("model", _MODELS)
def test_fireworks_get_num_tokens(model: str) -> None:
    """Test get_tokens."""
    llm = Fireworks(model=model)
    assert llm.get_num_tokens("表情符号是\n🦜🔗") == _EXPECTED_NUM_TOKENS[model]