Refactor Fireworks and add ChatFireworks (#3) (#10597)

Description
* Refactor Fireworks within LangChain LLMs.
* Remove FireworksChat from LangChain LLMs.
* Add ChatFireworks (which uses the chat completions API) to LangChain chat models.
* Users must install `fireworks-ai` and register an API key to use the API (see the usage sketch below).
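
A minimal usage sketch (not part of this diff), assuming `fireworks-ai` is installed and `FIREWORKS_API_KEY` is set in the environment:

```python
import os

from langchain.chat_models.fireworks import ChatFireworks
from langchain.llms.fireworks import Fireworks
from langchain.schema.messages import HumanMessage

# Placeholder only; in practice, export FIREWORKS_API_KEY before running.
os.environ.setdefault("FIREWORKS_API_KEY", "<your-fireworks-api-key>")

# Completion-style LLM (langchain.llms.fireworks.Fireworks).
llm = Fireworks(model="accounts/fireworks/models/llama-v2-7b-chat")
print(llm("Who's the best quarterback in the NFL?"))

# Chat-completions model (langchain.chat_models.fireworks.ChatFireworks).
chat = ChatFireworks(model="accounts/fireworks/models/llama-v2-7b-chat")
print(chat([HumanMessage(content="Hello")]).content)
```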

Issue - Not applicable
Dependencies - None
Tag maintainer - @rlancemartin @baskaryan
Cynthia Yang
2023-09-26 20:11:55 -07:00
committed by GitHub
parent 5514ebe859
commit 6dd44ff1c0
10 changed files with 1517 additions and 900 deletions

View File: langchain/chat_models/fireworks.py (new file)

@@ -0,0 +1,264 @@
import fireworks
import fireworks.client
from langchain.utils.env import get_from_dict_or_env
from pydantic import root_validator
from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Tuple, Union
from langchain.adapters.openai import convert_message_to_dict
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.llms.base import create_base_retry_decorator
from langchain.schema.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    BaseMessageChunk,
    ChatMessage,
    ChatMessageChunk,
    FunctionMessage,
    FunctionMessageChunk,
    HumanMessage,
    HumanMessageChunk,
    SystemMessage,
    SystemMessageChunk,
)
from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult
def _convert_delta_to_message_chunk(
_dict: Mapping[str, Any], default_class: type[BaseMessageChunk]
) -> BaseMessageChunk:
"""Convert a delta response to a message chunk."""
role = _dict.role
content = _dict.content or ""
additional_kwargs = {}
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
elif role == "function" or default_class == FunctionMessageChunk:
return FunctionMessageChunk(content=content, name=_dict.name)
elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role)
else:
return default_class(content=content)
def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
"""Convert a dict response to a message."""
role = _dict.role
content = _dict.content or ""
if role == "user":
return HumanMessage(content=content)
elif role == "assistant":
content = _dict.content
additional_kwargs = {}
return AIMessage(content=content, additional_kwargs=additional_kwargs)
elif role == "system":
return SystemMessage(content=content)
elif role == "function":
return FunctionMessage(content=content, name=_dict.name)
else:
return ChatMessage(content=content, role=role)
class ChatFireworks(BaseChatModel):
"""Fireworks Chat models."""
model: str = "accounts/fireworks/models/llama-v2-7b-chat"
model_kwargs: Optional[dict] = {"temperature": 0.7, "max_tokens": 512, "top_p": 1}
fireworks_api_key: Optional[str] = None
max_retries: int = 20
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key in environment."""
fireworks_api_key = get_from_dict_or_env(
values, "fireworks_api_key", "FIREWORKS_API_KEY"
)
fireworks.client.api_key = fireworks_api_key
return values
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "fireworks-chat"
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
message_dicts = self._create_message_dicts(messages, stop)
params = {
"model": self.model,
"messages": message_dicts,
**self.model_kwargs,
}
response = completion_with_retry(self, **params)
return self._create_chat_result(response)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
message_dicts = self._create_message_dicts(messages, stop)
params = {
"model": self.model,
"messages": message_dicts,
**self.model_kwargs,
}
response = await acompletion_with_retry(self, **params)
return self._create_chat_result(response)
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
return llm_outputs[0]
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
generations = []
for res in response.choices:
message = convert_dict_to_message(res.message)
gen = ChatGeneration(
message=message,
generation_info=dict(finish_reason=res.finish_reason),
)
generations.append(gen)
llm_output = {"model": self.model}
return ChatResult(generations=generations, llm_output=llm_output)
def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]]
) -> List[Dict[str, Any]]:
message_dicts = [convert_message_to_dict(m) for m in messages]
return message_dicts
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
message_dicts = self._create_message_dicts(messages, stop)
default_chunk_class = AIMessageChunk
params = {
"model": self.model,
"messages": message_dicts,
"stream": True,
**self.model_kwargs,
}
for chunk in completion_with_retry(self, **params):
choice = chunk.choices[0]
chunk = _convert_delta_to_message_chunk(choice.delta, default_chunk_class)
finish_reason = choice.finish_reason
generation_info = (
dict(finish_reason=finish_reason) if finish_reason is not None else None
)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk, generation_info=generation_info)
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
message_dicts = self._create_message_dicts(messages, stop)
default_chunk_class = AIMessageChunk
params = {
"model": self.model,
"messages": message_dicts,
"stream": True,
**self.model_kwargs,
}
async for chunk in await acompletion_with_retry_streaming(self, **params):
choice = chunk.choices[0]
chunk = _convert_delta_to_message_chunk(choice.delta, default_chunk_class)
finish_reason = choice.finish_reason
generation_info = (
dict(finish_reason=finish_reason) if finish_reason is not None else None
)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk, generation_info=generation_info)
def completion_with_retry(
llm: ChatFireworks,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
return fireworks.client.ChatCompletion.create(
**kwargs,
)
return _completion_with_retry(**kwargs)
async def acompletion_with_retry(
llm: ChatFireworks,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the async completion call."""
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@retry_decorator
async def _completion_with_retry(**kwargs: Any) -> Any:
return await fireworks.client.ChatCompletion.acreate(
**kwargs,
)
return await _completion_with_retry(**kwargs)
async def acompletion_with_retry_streaming(
llm: ChatFireworks,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the completion call for streaming."""
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@retry_decorator
async def _completion_with_retry(**kwargs: Any) -> Any:
return fireworks.client.ChatCompletion.acreate(
**kwargs,
)
return await _completion_with_retry(**kwargs)
def _create_retry_decorator(
llm: ChatFireworks,
run_manager: Optional[
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
] = None,
) -> Callable[[Any], Any]:
"""Define retry mechanism."""
errors = [
fireworks.client.error.RateLimitError,
fireworks.client.error.ServiceUnavailableError,
]
return create_base_retry_decorator(
error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
)

View File: langchain/llms/__init__.py

@@ -44,7 +44,7 @@ from langchain.llms.deepinfra import DeepInfra
from langchain.llms.deepsparse import DeepSparse
from langchain.llms.edenai import EdenAI
from langchain.llms.fake import FakeListLLM
from langchain.llms.fireworks import Fireworks, FireworksChat
from langchain.llms.fireworks import Fireworks
from langchain.llms.forefrontai import ForefrontAI
from langchain.llms.google_palm import GooglePalm
from langchain.llms.gooseai import GooseAI

View File: langchain/llms/fireworks.py

@@ -1,377 +1,220 @@
"""Wrapper around Fireworks APIs"""
import json
import logging
from typing import (
Any,
Dict,
List,
Optional,
Set,
Tuple,
Union,
)
import requests
import fireworks
import fireworks.client
from typing import Any, Callable, Dict, Iterator, List, Optional, Union
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.llms.base import BaseLLM
from langchain.pydantic_v1 import Field, root_validator
from langchain.schema import Generation, LLMResult
from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)
from langchain.llms.base import LLM, create_base_retry_decorator
from langchain.schema.language_model import LanguageModelInput
from langchain.schema.output import GenerationChunk
from langchain.schema.runnable.config import RunnableConfig
from langchain.utils.env import get_from_dict_or_env
from pydantic import root_validator
class BaseFireworks(BaseLLM):
"""Wrapper around Fireworks large language models."""
def _stream_response_to_generation_chunk(
stream_response: Dict[str, Any],
) -> GenerationChunk:
"""Convert a stream response to a generation chunk."""
return GenerationChunk(
text=stream_response.choices[0].text,
generation_info=dict(
finish_reason=stream_response.choices[0].finish_reason,
logprobs=stream_response.choices[0].logprobs,
),
)
model_id: str = Field("accounts/fireworks/models/llama-v2-7b-chat", alias="model")
"""Model name to use."""
temperature: float = 0.7
"""What sampling temperature to use."""
max_tokens: int = 512
"""The maximum number of tokens to generate in the completion.
-1 returns as many tokens as possible given the prompt and
the models maximal context size."""
top_p: float = 1
"""Total probability mass of tokens to consider at each step."""
class Fireworks(LLM):
"""Fireworks models."""
model: str = "accounts/fireworks/models/llama-v2-7b-chat"
model_kwargs: Optional[dict] = {"temperature": 0.7, "max_tokens": 512, "top_p": 1}
fireworks_api_key: Optional[str] = None
"""Api key to use fireworks API"""
batch_size: int = 20
"""Batch size to use when passing multiple documents to generate."""
request_timeout: Optional[Union[float, Tuple[float, float]]] = None
"""Timeout for requests to Fireworks completion API. Default is 600 seconds."""
max_retries: int = 6
"""Maximum number of retries to make when generating."""
@property
def lc_secrets(self) -> Dict[str, str]:
return {"fireworks_api_key": "FIREWORKS_API_KEY"}
@classmethod
def is_lc_serializable(cls) -> bool:
return True
def __new__(cls, **data: Any) -> Any:
"""Initialize the Fireworks object."""
data.get("model_id", "")
return super().__new__(cls)
class Config:
"""Configuration for this pydantic object."""
allow_population_by_field_name = True
max_retries: int = 20
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["fireworks_api_key"] = get_from_dict_or_env(
"""Validate that api key in environment."""
fireworks_api_key = get_from_dict_or_env(
values, "fireworks_api_key", "FIREWORKS_API_KEY"
)
fireworks.client.api_key = fireworks_api_key
return values
def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""Call out to Fireworks endpoint with k unique prompts.
Args:
prompts: The prompts to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
The full LLM output.
"""
params = {"model": self.model_id}
params = {**params, **kwargs}
sub_prompts = self.get_batch_prompts(params, prompts, stop)
choices = []
token_usage: Dict[str, int] = {}
_keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
for _prompts in sub_prompts:
response = completion_with_retry(self, prompt=prompts, **params)
choices.extend(response)
update_token_usage(_keys, response, token_usage)
return self.create_llm_result(choices, prompts, token_usage)
async def _agenerate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""Call out to Fireworks endpoint async with k unique prompts."""
params = {"model": self.model_id}
params = {**params, **kwargs}
sub_prompts = self.get_batch_prompts(params, prompts, stop)
choices = []
token_usage: Dict[str, int] = {}
_keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
for _prompts in sub_prompts:
response = await acompletion_with_retry(self, prompt=_prompts, **params)
choices.extend(response)
update_token_usage(_keys, response, token_usage)
return self.create_llm_result(choices, prompts, token_usage)
def get_batch_prompts(
self,
params: Dict[str, Any],
prompts: List[str],
stop: Optional[List[str]] = None,
) -> List[List[str]]:
"""Get the sub prompts for llm call."""
if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
sub_prompts = [
prompts[i : i + self.batch_size]
for i in range(0, len(prompts), self.batch_size)
]
return sub_prompts
def create_llm_result(
self, choices: Any, prompts: List[str], token_usage: Dict[str, int]
) -> LLMResult:
"""Create the LLMResult from the choices and prompts."""
generations = []
for i, _ in enumerate(prompts):
sub_choices = choices[i : (i + 1)]
generations.append(
[
Generation(
text=choice,
)
for choice in sub_choices
]
)
llm_output = {"token_usage": token_usage, "model_id": self.model_id}
return LLMResult(generations=generations, llm_output=llm_output)
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "fireworks"
class FireworksChat(BaseLLM):
"""Wrapper around Fireworks Chat large language models.
To use, you should have the ``fireworksai`` python package installed, and the
environment variable ``FIREWORKS_API_KEY`` set with your API key.
Any parameters that are valid to be passed to the fireworks.create
call can be passed in, even if not explicitly saved on this class.
Example:
.. code-block:: python
from langchain.llms import FireworksChat
fireworkschat = FireworksChat(model_id=""llama-v2-13b-chat"")
"""
model_id: str = "accounts/fireworks/models/llama-v2-7b-chat"
"""Model name to use."""
temperature: float = 0.7
"""What sampling temperature to use."""
max_tokens: int = 512
"""The maximum number of tokens to generate in the completion.
-1 returns as many tokens as possible given the prompt and
the models maximal context size."""
top_p: float = 1
"""Total probability mass of tokens to consider at each step."""
fireworks_api_key: Optional[str] = None
max_retries: int = 6
request_timeout: Optional[Union[float, Tuple[float, float]]] = None
"""Timeout for requests to Fireworks completion API. Default is 600 seconds."""
"""Maximum number of retries to make when generating."""
prefix_messages: List = Field(default_factory=list)
"""Series of messages for Chat input."""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment"""
values["fireworks_api_key"] = get_from_dict_or_env(
values, "fireworks_api_key", "FIREWORKS_API_KEY"
)
return values
def _get_chat_params(
self, prompts: List[str], stop: Optional[List[str]] = None
) -> Tuple:
if len(prompts) > 1:
raise ValueError(
f"FireworksChat currently only supports single prompt, got {prompts}"
)
messages = self.prefix_messages + [{"role": "user", "content": prompts[0]}]
params: Dict[str, Any] = {**{"model": self.model_id}}
if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
return messages, params
def _generate(
def _call(
self,
prompts: List[str],
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
messages, params = self._get_chat_params(prompts, stop)
params = {**params, **kwargs}
full_response = completion_with_retry(self, messages=messages, **params)
llm_output = {
"model_id": self.model_id,
) -> str:
"""Run the LLM on the given prompt and input."""
params = {
"model": self.model,
"prompt": prompt,
**self.model_kwargs,
}
return LLMResult(
generations=[[Generation(text=full_response[0])]],
llm_output=llm_output,
)
response = completion_with_retry(self, **params)
async def _agenerate(
return response.choices[0].text
async def _acall(
self,
prompts: List[str],
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
messages, params = self._get_chat_params(prompts, stop)
params = {**params, **kwargs}
full_response = await acompletion_with_retry(self, messages=messages, **params)
llm_output = {
"model_id": self.model_id,
) -> str:
"""Run the LLM on the given prompt and input."""
params = {
"model": self.model,
"prompt": prompt,
**self.model_kwargs,
}
return LLMResult(
generations=[[Generation(text=full_response[0])]],
llm_output=llm_output,
)
response = await acompletion_with_retry(self, **params)
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "fireworks-chat"
return response.choices[0].text
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
params = {
"model": self.model,
"prompt": prompt,
"stream": True,
**self.model_kwargs,
}
for stream_resp in completion_with_retry(self, **params):
chunk = _stream_response_to_generation_chunk(stream_resp)
yield chunk
class Fireworks(BaseFireworks):
"""Wrapper around Fireworks large language models.
To use, you should have the ``fireworks`` python package installed, and the
environment variable ``FIREWORKS_API_KEY`` set with your API key.
Any parameters that are valid to be passed to the fireworks.create
call can be passed in, even if not explicitly saved on this class.
Example:
.. code-block:: python
from langchain.llms import fireworks
llm = Fireworks(model_id="llama-v2-13b")
"""
async def _astream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
params = {
"model": self.model,
"prompt": prompt,
"stream": True,
**self.model_kwargs,
}
async for stream_resp in await acompletion_with_retry_streaming(self, **params):
chunk = _stream_response_to_generation_chunk(stream_resp)
yield chunk
def stream(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Iterator[str]:
prompt = self._convert_input(input).to_string()
generation: Optional[GenerationChunk] = None
for chunk in self._stream(prompt):
yield chunk.text
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
def update_token_usage(
keys: Set[str], response: Dict[str, Any], token_usage: Dict[str, Any]
) -> None:
"""Update token usage."""
_keys_to_use = keys.intersection(response)
for _key in _keys_to_use:
if _key not in token_usage:
token_usage[_key] = response["usage"][_key]
else:
token_usage[_key] += response["usage"][_key]
def execute(
prompt: str,
model: str,
api_key: Optional[str],
max_tokens: int = 256,
temperature: float = 0.0,
top_p: float = 1.0,
) -> Any:
"""Execute LLM query"""
requestUrl = "https://api.fireworks.ai/inference/v1/completions"
requestBody = {
"model": model,
"prompt": prompt,
"max_tokens": max_tokens,
"temperature": temperature,
"top_p": top_p,
}
requestHeaders = {
"Authorization": f"Bearer {api_key}",
"Accept": "application/json",
"Content-Type": "application/json",
}
response = requests.post(requestUrl, headers=requestHeaders, json=requestBody)
return response.text
async def astream(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Iterator[str]:
prompt = self._convert_input(input).to_string()
generation: Optional[GenerationChunk] = None
async for chunk in self._astream(prompt):
yield chunk.text
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
def completion_with_retry(
llm: Union[BaseFireworks, FireworksChat], **kwargs: Any
llm: Fireworks,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the completion call."""
if "prompt" not in kwargs.keys():
answers = []
for i in range(len(kwargs["messages"])):
result = kwargs["messages"][i]["content"]
result = execute(
result,
kwargs["model"],
llm.fireworks_api_key,
llm.max_tokens,
llm.temperature,
llm.top_p,
)
curr_string = json.loads(result)["choices"][0]["text"]
answers.append(curr_string)
else:
answers = []
for i in range(len(kwargs["prompt"])):
result = kwargs["prompt"][i]
result = execute(
result,
kwargs["model"],
llm.fireworks_api_key,
llm.max_tokens,
llm.temperature,
llm.top_p,
)
curr_string = json.loads(result)["choices"][0]["text"]
answers.append(curr_string)
return answers
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
return fireworks.client.Completion.create(
**kwargs,
)
return _completion_with_retry(**kwargs)
async def acompletion_with_retry(
llm: Union[BaseFireworks, FireworksChat], **kwargs: Any
llm: Fireworks,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the async completion call."""
if "prompt" not in kwargs.keys():
answers = []
for i in range(len(kwargs["messages"])):
result = kwargs["messages"][i]["content"]
result = execute(
result,
kwargs["model"],
llm.fireworks_api_key,
llm.max_tokens,
llm.temperature,
)
curr_string = json.loads(result)["choices"][0]["text"]
answers.append(curr_string)
else:
answers = []
for i in range(len(kwargs["prompt"])):
result = kwargs["prompt"][i]
result = execute(
result,
kwargs["model"],
llm.fireworks_api_key,
llm.max_tokens,
llm.temperature,
)
curr_string = json.loads(result)["choices"][0]["text"]
answers.append(curr_string)
return answers
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@retry_decorator
async def _completion_with_retry(**kwargs: Any) -> Any:
return await fireworks.client.Completion.acreate(
**kwargs,
)
return await _completion_with_retry(**kwargs)
async def acompletion_with_retry_streaming(
llm: Fireworks,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the completion call for streaming."""
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@retry_decorator
async def _completion_with_retry(**kwargs: Any) -> Any:
return fireworks.client.Completion.acreate(
**kwargs,
)
return await _completion_with_retry(**kwargs)
def _create_retry_decorator(
llm: Fireworks,
run_manager: Optional[
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
] = None,
) -> Callable[[Any], Any]:
"""Define retry mechanism."""
errors = [
fireworks.client.error.RateLimitError,
fireworks.client.error.ServiceUnavailableError,
]
return create_base_retry_decorator(
error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
)

View File: poetry.lock

@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
[[package]]
name = "absl-py"
@@ -2390,6 +2390,21 @@ calc = ["shapely"]
s3 = ["boto3 (>=1.3.1)"]
test = ["Fiona[s3]", "pytest (>=7)", "pytest-cov", "pytz"]
[[package]]
name = "fireworks-ai"
version = "0.4.1"
description = "Python client library for the Fireworks.ai Generative AI Platform"
optional = true
python-versions = ">=3.9"
files = [
{file = "fireworks_ai-0.4.1-py3-none-any.whl", hash = "sha256:6ac124ffcd783442b4569e4127adafb0bde6861b6ce5d7a7d162d3920e7cc4e9"},
]
[package.dependencies]
httpx = "*"
httpx-sse = "*"
pydantic = "*"
[[package]]
name = "flatbuffers"
version = "23.5.26"
@@ -3171,6 +3186,17 @@ cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"]
http2 = ["h2 (>=3,<5)"]
socks = ["socksio (==1.*)"]
[[package]]
name = "httpx-sse"
version = "0.3.1"
description = "Consume Server-Sent Event (SSE) messages with HTTPX."
optional = true
python-versions = ">=3.7"
files = [
{file = "httpx-sse-0.3.1.tar.gz", hash = "sha256:3bb3289b2867f50cbdb2fee3eeeefecb1e86653122e164faac0023f1ffc88aea"},
{file = "httpx_sse-0.3.1-py3-none-any.whl", hash = "sha256:7376dd88732892f9b6b549ac0ad05a8e2341172fe7dcf9f8f9c8050934297316"},
]
[[package]]
name = "huggingface-hub"
version = "0.16.4"
@@ -5818,7 +5844,7 @@ files = [
[package.dependencies]
numpy = [
{version = ">=1.20.3", markers = "python_version < \"3.10\""},
{version = ">=1.21.0", markers = "python_version >= \"3.10\""},
{version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""},
{version = ">=1.23.2", markers = "python_version >= \"3.11\""},
]
python-dateutil = ">=2.8.2"
@@ -8820,7 +8846,7 @@ files = [
]
[package.dependencies]
greenlet = {version = "!=0.4.17", markers = "platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\""}
greenlet = {version = "!=0.4.17", markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\""}
typing-extensions = ">=4.2.0"
[package.extras]
@@ -10622,4 +10648,4 @@ text-helpers = ["chardet"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "3a3749b3d63be94ef11de23ec7ad40cc20cca78fa7352c5ed7d537988ce90a85"
content-hash = "2d24ce7353641663405c132acfbd45492f56c4e53372eac4b698a6ace1eb27b7"

View File: pyproject.toml

@@ -132,6 +132,7 @@ sqlite-vss = {version = "^0.1.2", optional = true}
anyio = "<4.0"
jsonpatch = "^1.33"
timescale-vector = {version = "^0.0.1", optional = true}
fireworks-ai = {version = "^0.4.1", optional = true, python = ">=3.9"}
[tool.poetry.group.test.dependencies]

View File: test_fireworks.py (ChatFireworks integration tests, new file)

@@ -0,0 +1,106 @@
"""Test ChatFireworks wrapper."""
import pytest
from langchain.callbacks.manager import CallbackManager
from langchain.chat_models.fireworks import ChatFireworks
from langchain.schema import (
ChatGeneration,
ChatResult,
LLMResult,
)
from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage
def test_chat_fireworks() -> None:
"""Test ChatFireworks wrapper."""
chat = ChatFireworks()
message = HumanMessage(content="What is the weather in Redwood City, CA today")
response = chat([message])
assert isinstance(response, BaseMessage)
assert isinstance(response.content, str)
def test_chat_fireworks_model() -> None:
"""Test ChatFireworks wrapper handles model_name."""
chat = ChatFireworks(model="foo")
assert chat.model == "foo"
def test_chat_fireworks_system_message() -> None:
"""Test ChatFireworks wrapper with system message."""
chat = ChatFireworks()
system_message = SystemMessage(content="You are to chat with the user.")
human_message = HumanMessage(content="Hello")
response = chat([system_message, human_message])
assert isinstance(response, BaseMessage)
assert isinstance(response.content, str)
def test_chat_fireworks_generate() -> None:
"""Test ChatFireworks wrapper with generate."""
chat = ChatFireworks(model_kwargs={"n": 2})
message = HumanMessage(content="Hello")
response = chat.generate([[message], [message]])
assert isinstance(response, LLMResult)
assert len(response.generations) == 2
for generations in response.generations:
assert len(generations) == 2
for generation in generations:
assert isinstance(generation, ChatGeneration)
assert isinstance(generation.text, str)
assert generation.text == generation.message.content
def test_chat_fireworks_multiple_completions() -> None:
"""Test ChatFireworks wrapper with multiple completions."""
chat = ChatFireworks(model_kwargs={"n": 5})
message = HumanMessage(content="Hello")
response = chat._generate([message])
assert isinstance(response, ChatResult)
assert len(response.generations) == 5
for generation in response.generations:
assert isinstance(generation.message, BaseMessage)
assert isinstance(generation.message.content, str)
def test_chat_fireworks_llm_output_contains_model_id() -> None:
"""Test llm_output contains model_id."""
chat = ChatFireworks()
message = HumanMessage(content="Hello")
llm_result = chat.generate([[message]])
assert llm_result.llm_output is not None
assert llm_result.llm_output["model"] == chat.model
def test_fireworks_streaming() -> None:
"""Test streaming tokens from OpenAI."""
llm = ChatFireworks()
for token in llm.stream("I'm Pickle Rick"):
assert isinstance(token.content, str)
@pytest.mark.asyncio
async def test_chat_fireworks_agenerate() -> None:
"""Test ChatFireworks wrapper with generate."""
chat = ChatFireworks(model_kwargs={"n": 2})
message = HumanMessage(content="Hello")
response = await chat.agenerate([[message], [message]])
assert isinstance(response, LLMResult)
assert len(response.generations) == 2
for generations in response.generations:
assert len(generations) == 2
for generation in generations:
assert isinstance(generation, ChatGeneration)
assert isinstance(generation.text, str)
assert generation.text == generation.message.content
@pytest.mark.asyncio
async def test_fireworks_astream() -> None:
"""Test streaming tokens from OpenAI."""
llm = ChatFireworks()
async for token in llm.astream("Who's the best quarterback in the NFL?"):
assert isinstance(token.content, str)

View File: test_fireworks.py (Fireworks LLM integration tests)

@@ -1,31 +1,20 @@
"""Test Fireworks AI API Wrapper."""
from pathlib import Path
import pytest
from langchain.chains import RetrievalQA
from langchain.chains.llm import LLMChain
from langchain.document_loaders import TextLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import OpenAIChat
from langchain.llms.fireworks import Fireworks, FireworksChat
from langchain.llms.loading import load_llm
from typing import Generator
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.llms.fireworks import Fireworks
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
)
from langchain.schema import LLMResult
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import DeepLake
import pytest
def test_fireworks_call() -> None:
"""Test valid call to fireworks."""
llm = Fireworks(
model_id="accounts/fireworks/models/fireworks-llama-v2-13b-chat", max_tokens=900
)
output = llm("What is the weather in NYC")
llm = Fireworks()
output = llm("Who's the best quarterback in the NFL?")
assert isinstance(output, str)
@@ -44,36 +33,10 @@ def test_fireworks_in_chain() -> None:
assert isinstance(output, str)
@pytest.mark.asyncio
async def test_openai_chat_async_generate() -> None:
"""Test async chat."""
llm = OpenAIChat(max_tokens=10)
output = await llm.agenerate(["Hello, how are you?"])
assert isinstance(output, LLMResult)
def test_fireworks_model_param() -> None:
"""Tests model parameters for Fireworks"""
llm = Fireworks(model="foo")
assert llm.model_id == "foo"
llm = Fireworks(model_id="foo")
assert llm.model_id == "foo"
def test_fireworkschat_model_param() -> None:
"""Tests model parameters for FireworksChat"""
llm = FireworksChat(model="foo")
assert llm.model_id == "foo"
llm = FireworksChat(model_id="foo")
assert llm.model_id == "foo"
def test_saving_loading_llm(tmp_path: Path) -> None:
"""Test saving/loading an Fireworks LLM."""
llm = Fireworks(max_tokens=10)
llm.save(file_path=tmp_path / "fireworks.yaml")
loaded_llm = load_llm(tmp_path / "fireworks.yaml")
assert loaded_llm == llm
assert llm.model == "foo"
def test_fireworks_multiple_prompts() -> None:
@@ -85,76 +48,39 @@ def test_fireworks_multiple_prompts() -> None:
assert len(output.generations) == 2
def test_fireworks_chat() -> None:
"""Test FireworksChat."""
llm = FireworksChat()
output = llm("Name me 3 quick facts about the New England Patriots")
assert isinstance(output, str)
async def test_fireworks_agenerate() -> None:
def test_fireworks_streaming() -> None:
"""Test stream completion."""
llm = Fireworks()
output = await llm.agenerate(["I'm a pickle", "I'm a pickle"])
generator = llm.stream("Who's the best quarterback in the NFL?")
assert isinstance(generator, Generator)
for token in generator:
assert isinstance(token, str)
@pytest.mark.asyncio
async def test_fireworks_streaming_async() -> None:
"""Test stream completion."""
llm = Fireworks()
async for token in llm.astream("Who's the best quarterback in the NFL?"):
assert isinstance(token, str)
@pytest.mark.asyncio
async def test_fireworks_async_agenerate() -> None:
"""Test async."""
llm = Fireworks()
output = await llm.agenerate(["What is the best city to live in California?"])
assert isinstance(output, LLMResult)
@pytest.mark.asyncio
async def test_fireworks_multiple_prompts_async_agenerate() -> None:
llm = Fireworks()
output = await llm.agenerate(
["How is the weather in New York today?", "I'm pickle rick"]
)
assert isinstance(output, LLMResult)
assert isinstance(output.generations, list)
assert len(output.generations) == 2
async def test_fireworkschat_agenerate() -> None:
llm = FireworksChat(max_tokens=10)
output = await llm.agenerate(["Hello, how are you?"])
assert isinstance(output, LLMResult)
assert isinstance(output.generations, list)
assert len(output.generations) == 1
def test_fireworkschat_chain() -> None:
embeddings = OpenAIEmbeddings()
loader = TextLoader(
"[workspace]/langchain-internal/docs/extras/modules/state_of_the_union.txt"
)
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
docs = text_splitter.split_documents(documents)
embeddings = OpenAIEmbeddings()
db = DeepLake(
dataset_path="./my_deeplake/", embedding_function=embeddings, overwrite=True
)
db.add_documents(docs)
query = "What did the president say about Ketanji Brown Jackson"
docs = db.similarity_search(query)
qa = RetrievalQA.from_chain_type(
llm=FireworksChat(),
chain_type="stuff",
retriever=db.as_retriever(),
)
query = "What did the president say about Ketanji Brown Jackson"
output = qa.run(query)
assert isinstance(output, str)
_EXPECTED_NUM_TOKENS = {
"accounts/fireworks/models/fireworks-llama-v2-13b": 17,
"accounts/fireworks/models/fireworks-llama-v2-7b": 17,
"accounts/fireworks/models/fireworks-llama-v2-13b-chat": 17,
"accounts/fireworks/models/fireworks-llama-v2-7b-chat": 17,
}
_MODELS = models = [
"accounts/fireworks/models/fireworks-llama-v2-13b",
"accounts/fireworks/models/fireworks-llama-v2-7b",
"accounts/fireworks/models/fireworks-llama-v2-13b-chat",
"accounts/fireworks/models/fireworks-llama-v2-7b-chat",
]
@pytest.mark.parametrize("model", _MODELS)
def test_fireworks_get_num_tokens(model: str) -> None:
"""Test get_tokens."""
llm = Fireworks(model=model)
assert llm.get_num_tokens("表情符号是\n🦜🔗") == _EXPECTED_NUM_TOKENS[model]