mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-09 23:12:38 +00:00
Description
* Refactor Fireworks within LangChain LLMs.
* Remove FireworksChat from LangChain LLMs.
* Add ChatFireworks (which uses the chat completion API) to LangChain chat models.
* Users have to install `fireworks-ai` and register an API key to use the API.

Issue - Not applicable
Dependencies - None
Tag maintainer - @rlancemartin @baskaryan
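For reference, a minimal usage sketch of the new chat model wrapper (not part of the diff itself; it assumes `pip install fireworks-ai` and a FIREWORKS_API_KEY set in the environment, as the description requires, and mirrors the integration tests added below):

from langchain.chat_models.fireworks import ChatFireworks
from langchain.schema.messages import HumanMessage

chat = ChatFireworks()  # defaults to "accounts/fireworks/models/llama-v2-7b-chat"
response = chat([HumanMessage(content="Name three colors.")])
print(response.content)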
libs/langchain/langchain/chat_models/fireworks.py (264 lines, new file)
@@ -0,0 +1,264 @@
import fireworks
import fireworks.client
from langchain.utils.env import get_from_dict_or_env
from pydantic import root_validator
from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Tuple, Union
from langchain.adapters.openai import convert_message_to_dict
from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.llms.base import create_base_retry_decorator
from langchain.schema.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    BaseMessageChunk,
    ChatMessage,
    ChatMessageChunk,
    FunctionMessage,
    FunctionMessageChunk,
    HumanMessage,
    HumanMessageChunk,
    SystemMessage,
    SystemMessageChunk,
)
from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult


def _convert_delta_to_message_chunk(
    _dict: Mapping[str, Any], default_class: type[BaseMessageChunk]
) -> BaseMessageChunk:
    """Convert a delta response to a message chunk."""
    role = _dict.role
    content = _dict.content or ""
    additional_kwargs = {}

    if role == "user" or default_class == HumanMessageChunk:
        return HumanMessageChunk(content=content)
    elif role == "assistant" or default_class == AIMessageChunk:
        return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
    elif role == "system" or default_class == SystemMessageChunk:
        return SystemMessageChunk(content=content)
    elif role == "function" or default_class == FunctionMessageChunk:
        return FunctionMessageChunk(content=content, name=_dict.name)
    elif role or default_class == ChatMessageChunk:
        return ChatMessageChunk(content=content, role=role)
    else:
        return default_class(content=content)


def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
    """Convert a dict response to a message."""
    role = _dict.role
    content = _dict.content or ""
    if role == "user":
        return HumanMessage(content=content)
    elif role == "assistant":
        content = _dict.content
        additional_kwargs = {}
        return AIMessage(content=content, additional_kwargs=additional_kwargs)
    elif role == "system":
        return SystemMessage(content=content)
    elif role == "function":
        return FunctionMessage(content=content, name=_dict.name)
    else:
        return ChatMessage(content=content, role=role)


class ChatFireworks(BaseChatModel):
    """Fireworks Chat models."""

    model: str = "accounts/fireworks/models/llama-v2-7b-chat"
    model_kwargs: Optional[dict] = {"temperature": 0.7, "max_tokens": 512, "top_p": 1}
    fireworks_api_key: Optional[str] = None
    max_retries: int = 20

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key in environment."""
        fireworks_api_key = get_from_dict_or_env(
            values, "fireworks_api_key", "FIREWORKS_API_KEY"
        )
        fireworks.client.api_key = fireworks_api_key
        return values

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "fireworks-chat"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        message_dicts = self._create_message_dicts(messages, stop)

        params = {
            "model": self.model,
            "messages": message_dicts,
            **self.model_kwargs,
        }
        response = completion_with_retry(self, **params)
        return self._create_chat_result(response)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        message_dicts = self._create_message_dicts(messages, stop)
        params = {
            "model": self.model,
            "messages": message_dicts,
            **self.model_kwargs,
        }
        response = await acompletion_with_retry(self, **params)
        return self._create_chat_result(response)

    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        return llm_outputs[0]

    def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
        generations = []
        for res in response.choices:
            message = convert_dict_to_message(res.message)
            gen = ChatGeneration(
                message=message,
                generation_info=dict(finish_reason=res.finish_reason),
            )
            generations.append(gen)
        llm_output = {"model": self.model}
        return ChatResult(generations=generations, llm_output=llm_output)

    def _create_message_dicts(
        self, messages: List[BaseMessage], stop: Optional[List[str]]
    ) -> Tuple[List[Dict[str, Any]]]:
        message_dicts = [convert_message_to_dict(m) for m in messages]
        return message_dicts

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        message_dicts = self._create_message_dicts(messages, stop)
        default_chunk_class = AIMessageChunk
        params = {
            "model": self.model,
            "messages": message_dicts,
            "stream": True,
            **self.model_kwargs,
        }
        for chunk in completion_with_retry(self, **params):
            choice = chunk.choices[0]
            chunk = _convert_delta_to_message_chunk(choice.delta, default_chunk_class)
            finish_reason = choice.finish_reason
            generation_info = (
                dict(finish_reason=finish_reason) if finish_reason is not None else None
            )
            default_chunk_class = chunk.__class__
            yield ChatGenerationChunk(message=chunk, generation_info=generation_info)

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        message_dicts = self._create_message_dicts(messages, stop)
        default_chunk_class = AIMessageChunk
        params = {
            "model": self.model,
            "messages": message_dicts,
            "stream": True,
            **self.model_kwargs,
        }
        async for chunk in await acompletion_with_retry_streaming(self, **params):
            choice = chunk.choices[0]
            chunk = _convert_delta_to_message_chunk(choice.delta, default_chunk_class)
            finish_reason = choice.finish_reason
            generation_info = (
                dict(finish_reason=finish_reason) if finish_reason is not None else None
            )
            default_chunk_class = chunk.__class__
            yield ChatGenerationChunk(message=chunk, generation_info=generation_info)


def completion_with_retry(
    llm: ChatFireworks,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the completion call."""
    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

    @retry_decorator
    def _completion_with_retry(**kwargs: Any) -> Any:
        return fireworks.client.ChatCompletion.create(
            **kwargs,
        )

    return _completion_with_retry(**kwargs)


async def acompletion_with_retry(
    llm: ChatFireworks,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the async completion call."""
    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

    @retry_decorator
    async def _completion_with_retry(**kwargs: Any) -> Any:
        return await fireworks.client.ChatCompletion.acreate(
            **kwargs,
        )

    return await _completion_with_retry(**kwargs)


async def acompletion_with_retry_streaming(
    llm: ChatFireworks,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the completion call for streaming."""
    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

    @retry_decorator
    async def _completion_with_retry(**kwargs: Any) -> Any:
        return fireworks.client.ChatCompletion.acreate(
            **kwargs,
        )

    return await _completion_with_retry(**kwargs)


def _create_retry_decorator(
    llm: ChatFireworks,
    run_manager: Optional[
        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
    ] = None,
) -> Callable[[Any], Any]:
    """Define retry mechanism."""
    errors = [
        fireworks.client.error.RateLimitError,
        fireworks.client.error.ServiceUnavailableError,
    ]
    return create_base_retry_decorator(
        error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
    )
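The streaming path above can be exercised as in the sketch below (an illustrative example, not part of the commit; it relies on the same environment setup and on chat.stream accepting a plain string, as the integration tests later in this diff do):

from langchain.chat_models.fireworks import ChatFireworks

chat = ChatFireworks()
for chunk in chat.stream("Tell me a one-line joke."):
    # each chunk is a message chunk built by _convert_delta_to_message_chunk above
    print(chunk.content, end="", flush=True)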
libs/langchain/langchain/llms/__init__.py
@@ -44,7 +44,7 @@ from langchain.llms.deepinfra import DeepInfra
from langchain.llms.deepsparse import DeepSparse
from langchain.llms.edenai import EdenAI
from langchain.llms.fake import FakeListLLM
from langchain.llms.fireworks import Fireworks, FireworksChat
from langchain.llms.fireworks import Fireworks
from langchain.llms.forefrontai import ForefrontAI
from langchain.llms.google_palm import GooglePalm
from langchain.llms.gooseai import GooseAI
libs/langchain/langchain/llms/fireworks.py
@@ -1,377 +1,220 @@
"""Wrapper around Fireworks APIs"""
import json
import logging
from typing import (
    Any,
    Dict,
    List,
    Optional,
    Set,
    Tuple,
    Union,
)

import requests

import fireworks
import fireworks.client
from typing import Any, Callable, Dict, Iterator, List, Optional, Union
from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.llms.base import BaseLLM
from langchain.pydantic_v1 import Field, root_validator
from langchain.schema import Generation, LLMResult
from langchain.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)
from langchain.llms.base import LLM, create_base_retry_decorator
from langchain.schema.language_model import LanguageModelInput
from langchain.schema.output import GenerationChunk
from langchain.schema.runnable.config import RunnableConfig
from langchain.utils.env import get_from_dict_or_env
from pydantic import root_validator


class BaseFireworks(BaseLLM):
    """Wrapper around Fireworks large language models."""
def _stream_response_to_generation_chunk(
    stream_response: Dict[str, Any],
) -> GenerationChunk:
    """Convert a stream response to a generation chunk."""
    return GenerationChunk(
        text=stream_response.choices[0].text,
        generation_info=dict(
            finish_reason=stream_response.choices[0].finish_reason,
            logprobs=stream_response.choices[0].logprobs,
        ),
    )

    model_id: str = Field("accounts/fireworks/models/llama-v2-7b-chat", alias="model")
    """Model name to use."""
    temperature: float = 0.7
    """What sampling temperature to use."""
    max_tokens: int = 512
    """The maximum number of tokens to generate in the completion.
    -1 returns as many tokens as possible given the prompt and
    the models maximal context size."""
    top_p: float = 1
    """Total probability mass of tokens to consider at each step."""

class Fireworks(LLM):
    """Fireworks models."""

    model: str = "accounts/fireworks/models/llama-v2-7b-chat"
    model_kwargs: Optional[dict] = {"temperature": 0.7, "max_tokens": 512, "top_p": 1}
    fireworks_api_key: Optional[str] = None
    """Api key to use fireworks API"""
    batch_size: int = 20
    """Batch size to use when passing multiple documents to generate."""
    request_timeout: Optional[Union[float, Tuple[float, float]]] = None
    """Timeout for requests to Fireworks completion API. Default is 600 seconds."""
    max_retries: int = 6
    """Maximum number of retries to make when generating."""

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"fireworks_api_key": "FIREWORKS_API_KEY"}

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return True

    def __new__(cls, **data: Any) -> Any:
        """Initialize the Fireworks object."""
        data.get("model_id", "")
        return super().__new__(cls)

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True
    max_retries: int = 20

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        values["fireworks_api_key"] = get_from_dict_or_env(
        """Validate that api key in environment."""
        fireworks_api_key = get_from_dict_or_env(
            values, "fireworks_api_key", "FIREWORKS_API_KEY"
        )
        fireworks.client.api_key = fireworks_api_key
        return values

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Call out to Fireworks endpoint with k unique prompts.
        Args:
            prompts: The prompts to pass into the model.
            stop: Optional list of stop words to use when generating.
        Returns:
            The full LLM output.
        """
        params = {"model": self.model_id}
        params = {**params, **kwargs}
        sub_prompts = self.get_batch_prompts(params, prompts, stop)
        choices = []
        token_usage: Dict[str, int] = {}
        _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
        for _prompts in sub_prompts:
            response = completion_with_retry(self, prompt=prompts, **params)
            choices.extend(response)
            update_token_usage(_keys, response, token_usage)

        return self.create_llm_result(choices, prompts, token_usage)

    async def _agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Call out to Fireworks endpoint async with k unique prompts."""
        params = {"model": self.model_id}
        params = {**params, **kwargs}
        sub_prompts = self.get_batch_prompts(params, prompts, stop)
        choices = []
        token_usage: Dict[str, int] = {}
        _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
        for _prompts in sub_prompts:
            response = await acompletion_with_retry(self, prompt=_prompts, **params)
            choices.extend(response)
            update_token_usage(_keys, response, token_usage)

        return self.create_llm_result(choices, prompts, token_usage)

    def get_batch_prompts(
        self,
        params: Dict[str, Any],
        prompts: List[str],
        stop: Optional[List[str]] = None,
    ) -> List[List[str]]:
        """Get the sub prompts for llm call."""
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")

        sub_prompts = [
            prompts[i : i + self.batch_size]
            for i in range(0, len(prompts), self.batch_size)
        ]
        return sub_prompts

    def create_llm_result(
        self, choices: Any, prompts: List[str], token_usage: Dict[str, int]
    ) -> LLMResult:
        """Create the LLMResult from the choices and prompts."""
        generations = []

        for i, _ in enumerate(prompts):
            sub_choices = choices[i : (i + 1)]
            generations.append(
                [
                    Generation(
                        text=choice,
                    )
                    for choice in sub_choices
                ]
            )
        llm_output = {"token_usage": token_usage, "model_id": self.model_id}
        return LLMResult(generations=generations, llm_output=llm_output)

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "fireworks"


class FireworksChat(BaseLLM):
    """Wrapper around Fireworks Chat large language models.
    To use, you should have the ``fireworksai`` python package installed, and the
    environment variable ``FIREWORKS_API_KEY`` set with your API key.
    Any parameters that are valid to be passed to the fireworks.create
    call can be passed in, even if not explicitly saved on this class.
    Example:
        .. code-block:: python
            from langchain.llms import FireworksChat
            fireworkschat = FireworksChat(model_id=""llama-v2-13b-chat"")
    """

    model_id: str = "accounts/fireworks/models/llama-v2-7b-chat"
    """Model name to use."""
    temperature: float = 0.7
    """What sampling temperature to use."""
    max_tokens: int = 512
    """The maximum number of tokens to generate in the completion.
    -1 returns as many tokens as possible given the prompt and
    the models maximal context size."""
    top_p: float = 1
    """Total probability mass of tokens to consider at each step."""
    fireworks_api_key: Optional[str] = None
    max_retries: int = 6
    request_timeout: Optional[Union[float, Tuple[float, float]]] = None
    """Timeout for requests to Fireworks completion API. Default is 600 seconds."""
    """Maximum number of retries to make when generating."""
    prefix_messages: List = Field(default_factory=list)
    """Series of messages for Chat input."""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment"""
        values["fireworks_api_key"] = get_from_dict_or_env(
            values, "fireworks_api_key", "FIREWORKS_API_KEY"
        )
        return values

    def _get_chat_params(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> Tuple:
        if len(prompts) > 1:
            raise ValueError(
                f"FireworksChat currently only supports single prompt, got {prompts}"
            )
        messages = self.prefix_messages + [{"role": "user", "content": prompts[0]}]
        params: Dict[str, Any] = {**{"model": self.model_id}}
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")

        return messages, params

    def _generate(
    def _call(
        self,
        prompts: List[str],
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        messages, params = self._get_chat_params(prompts, stop)
        params = {**params, **kwargs}
        full_response = completion_with_retry(self, messages=messages, **params)
        llm_output = {
            "model_id": self.model_id,
    ) -> str:
        """Run the LLM on the given prompt and input."""
        params = {
            "model": self.model,
            "prompt": prompt,
            **self.model_kwargs,
        }
        return LLMResult(
            generations=[[Generation(text=full_response[0])]],
            llm_output=llm_output,
        )
        response = completion_with_retry(self, **params)

    async def _agenerate(
        return response.choices[0].text

    async def _acall(
        self,
        prompts: List[str],
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        messages, params = self._get_chat_params(prompts, stop)
        params = {**params, **kwargs}
        full_response = await acompletion_with_retry(self, messages=messages, **params)
        llm_output = {
            "model_id": self.model_id,
    ) -> str:
        """Run the LLM on the given prompt and input."""
        params = {
            "model": self.model,
            "prompt": prompt,
            **self.model_kwargs,
        }
        return LLMResult(
            generations=[[Generation(text=full_response[0])]],
            llm_output=llm_output,
        )
        response = await acompletion_with_retry(self, **params)

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "fireworks-chat"
        return response.choices[0].text

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        params = {
            "model": self.model,
            "prompt": prompt,
            "stream": True,
            **self.model_kwargs,
        }
        for stream_resp in completion_with_retry(self, **params):
            chunk = _stream_response_to_generation_chunk(stream_resp)
            yield chunk

class Fireworks(BaseFireworks):
    """Wrapper around Fireworks large language models.
    To use, you should have the ``fireworks`` python package installed, and the
    environment variable ``FIREWORKS_API_KEY`` set with your API key.
    Any parameters that are valid to be passed to the fireworks.create
    call can be passed in, even if not explicitly saved on this class.
    Example:
        .. code-block:: python
            from langchain.llms import fireworks
            llm = Fireworks(model_id="llama-v2-13b")
    """
    async def _astream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        params = {
            "model": self.model,
            "prompt": prompt,
            "stream": True,
            **self.model_kwargs,
        }
        async for stream_resp in await acompletion_with_retry_streaming(self, **params):
            chunk = _stream_response_to_generation_chunk(stream_resp)
            yield chunk

    def stream(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Iterator[str]:
        prompt = self._convert_input(input).to_string()
        generation: Optional[GenerationChunk] = None
        for chunk in self._stream(prompt):
            yield chunk.text
            if generation is None:
                generation = chunk
            else:
                generation += chunk
        assert generation is not None

def update_token_usage(
    keys: Set[str], response: Dict[str, Any], token_usage: Dict[str, Any]
) -> None:
    """Update token usage."""
    _keys_to_use = keys.intersection(response)
    for _key in _keys_to_use:
        if _key not in token_usage:
            token_usage[_key] = response["usage"][_key]
        else:
            token_usage[_key] += response["usage"][_key]


def execute(
    prompt: str,
    model: str,
    api_key: Optional[str],
    max_tokens: int = 256,
    temperature: float = 0.0,
    top_p: float = 1.0,
) -> Any:
    """Execute LLM query"""
    requestUrl = "https://api.fireworks.ai/inference/v1/completions"
    requestBody = {
        "model": model,
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
    }
    requestHeaders = {
        "Authorization": f"Bearer {api_key}",
        "Accept": "application/json",
        "Content-Type": "application/json",
    }
    response = requests.post(requestUrl, headers=requestHeaders, json=requestBody)
    return response.text
    async def astream(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Iterator[str]:
        prompt = self._convert_input(input).to_string()
        generation: Optional[GenerationChunk] = None
        async for chunk in self._astream(prompt):
            yield chunk.text
            if generation is None:
                generation = chunk
            else:
                generation += chunk
        assert generation is not None


def completion_with_retry(
    llm: Union[BaseFireworks, FireworksChat], **kwargs: Any
    llm: Fireworks,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the completion call."""
    if "prompt" not in kwargs.keys():
        answers = []
        for i in range(len(kwargs["messages"])):
            result = kwargs["messages"][i]["content"]
            result = execute(
                result,
                kwargs["model"],
                llm.fireworks_api_key,
                llm.max_tokens,
                llm.temperature,
                llm.top_p,
            )
            curr_string = json.loads(result)["choices"][0]["text"]
            answers.append(curr_string)
    else:
        answers = []
        for i in range(len(kwargs["prompt"])):
            result = kwargs["prompt"][i]
            result = execute(
                result,
                kwargs["model"],
                llm.fireworks_api_key,
                llm.max_tokens,
                llm.temperature,
                llm.top_p,
            )
            curr_string = json.loads(result)["choices"][0]["text"]
            answers.append(curr_string)
    return answers
    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

    @retry_decorator
    def _completion_with_retry(**kwargs: Any) -> Any:
        return fireworks.client.Completion.create(
            **kwargs,
        )

    return _completion_with_retry(**kwargs)


async def acompletion_with_retry(
    llm: Union[BaseFireworks, FireworksChat], **kwargs: Any
    llm: Fireworks,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the async completion call."""
    if "prompt" not in kwargs.keys():
        answers = []
        for i in range(len(kwargs["messages"])):
            result = kwargs["messages"][i]["content"]
            result = execute(
                result,
                kwargs["model"],
                llm.fireworks_api_key,
                llm.max_tokens,
                llm.temperature,
            )
            curr_string = json.loads(result)["choices"][0]["text"]
            answers.append(curr_string)
    else:
        answers = []
        for i in range(len(kwargs["prompt"])):
            result = kwargs["prompt"][i]
            result = execute(
                result,
                kwargs["model"],
                llm.fireworks_api_key,
                llm.max_tokens,
                llm.temperature,
            )
            curr_string = json.loads(result)["choices"][0]["text"]
            answers.append(curr_string)
    return answers
    """Use tenacity to retry the completion call."""
    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

    @retry_decorator
    async def _completion_with_retry(**kwargs: Any) -> Any:
        return await fireworks.client.Completion.acreate(
            **kwargs,
        )

    return await _completion_with_retry(**kwargs)


async def acompletion_with_retry_streaming(
    llm: Fireworks,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the completion call for streaming."""
    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

    @retry_decorator
    async def _completion_with_retry(**kwargs: Any) -> Any:
        return fireworks.client.Completion.acreate(
            **kwargs,
        )

    return await _completion_with_retry(**kwargs)


def _create_retry_decorator(
    llm: Fireworks,
    run_manager: Optional[
        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
    ] = None,
) -> Callable[[Any], Any]:
    """Define retry mechanism."""
    errors = [
        fireworks.client.error.RateLimitError,
        fireworks.client.error.ServiceUnavailableError,
    ]
    return create_base_retry_decorator(
        error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
    )
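A hedged sketch of how the refactored completion-style Fireworks wrapper above is used after this change (it mirrors test_fireworks_call and test_fireworks_streaming in the updated tests, with the same install and API-key assumptions as earlier):

from langchain.llms.fireworks import Fireworks

llm = Fireworks()  # defaults to "accounts/fireworks/models/llama-v2-7b-chat"
print(llm("Who's the best quarterback in the NFL?"))
for token in llm.stream("Who's the best quarterback in the NFL?"):
    # Fireworks.stream yields plain strings
    print(token, end="", flush=True)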
libs/langchain/poetry.lock (generated, 34 changed lines)
@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.

[[package]]
name = "absl-py"
@@ -2390,6 +2390,21 @@ calc = ["shapely"]
s3 = ["boto3 (>=1.3.1)"]
test = ["Fiona[s3]", "pytest (>=7)", "pytest-cov", "pytz"]

[[package]]
name = "fireworks-ai"
version = "0.4.1"
description = "Python client library for the Fireworks.ai Generative AI Platform"
optional = true
python-versions = ">=3.9"
files = [
    {file = "fireworks_ai-0.4.1-py3-none-any.whl", hash = "sha256:6ac124ffcd783442b4569e4127adafb0bde6861b6ce5d7a7d162d3920e7cc4e9"},
]

[package.dependencies]
httpx = "*"
httpx-sse = "*"
pydantic = "*"

[[package]]
name = "flatbuffers"
version = "23.5.26"
@@ -3171,6 +3186,17 @@ cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"]
http2 = ["h2 (>=3,<5)"]
socks = ["socksio (==1.*)"]

[[package]]
name = "httpx-sse"
version = "0.3.1"
description = "Consume Server-Sent Event (SSE) messages with HTTPX."
optional = true
python-versions = ">=3.7"
files = [
    {file = "httpx-sse-0.3.1.tar.gz", hash = "sha256:3bb3289b2867f50cbdb2fee3eeeefecb1e86653122e164faac0023f1ffc88aea"},
    {file = "httpx_sse-0.3.1-py3-none-any.whl", hash = "sha256:7376dd88732892f9b6b549ac0ad05a8e2341172fe7dcf9f8f9c8050934297316"},
]

[[package]]
name = "huggingface-hub"
version = "0.16.4"
@@ -5818,7 +5844,7 @@ files = [
[package.dependencies]
numpy = [
    {version = ">=1.20.3", markers = "python_version < \"3.10\""},
    {version = ">=1.21.0", markers = "python_version >= \"3.10\""},
    {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""},
    {version = ">=1.23.2", markers = "python_version >= \"3.11\""},
]
python-dateutil = ">=2.8.2"
@@ -8820,7 +8846,7 @@ files = [
]

[package.dependencies]
greenlet = {version = "!=0.4.17", markers = "platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\""}
greenlet = {version = "!=0.4.17", markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\""}
typing-extensions = ">=4.2.0"

[package.extras]
@@ -10622,4 +10648,4 @@ text-helpers = ["chardet"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "3a3749b3d63be94ef11de23ec7ad40cc20cca78fa7352c5ed7d537988ce90a85"
content-hash = "2d24ce7353641663405c132acfbd45492f56c4e53372eac4b698a6ace1eb27b7"
libs/langchain/pyproject.toml
@@ -132,6 +132,7 @@ sqlite-vss = {version = "^0.1.2", optional = true}
anyio = "<4.0"
jsonpatch = "^1.33"
timescale-vector = {version = "^0.0.1", optional = true}
fireworks-ai = {version = "^0.4.1", optional = true, python = ">=3.9"}


[tool.poetry.group.test.dependencies]
libs/langchain/tests/integration_tests/chat_models/test_fireworks.py (new file)
@@ -0,0 +1,106 @@
"""Test ChatFireworks wrapper."""

import pytest

from langchain.callbacks.manager import CallbackManager
from langchain.chat_models.fireworks import ChatFireworks
from langchain.schema import (
    ChatGeneration,
    ChatResult,
    LLMResult,
)
from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage


def test_chat_fireworks() -> None:
    """Test ChatFireworks wrapper."""
    chat = ChatFireworks()
    message = HumanMessage(content="What is the weather in Redwood City, CA today")
    response = chat([message])
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_chat_fireworks_model() -> None:
    """Test ChatFireworks wrapper handles model_name."""
    chat = ChatFireworks(model="foo")
    assert chat.model == "foo"


def test_chat_fireworks_system_message() -> None:
    """Test ChatFireworks wrapper with system message."""
    chat = ChatFireworks()
    system_message = SystemMessage(content="You are to chat with the user.")
    human_message = HumanMessage(content="Hello")
    response = chat([system_message, human_message])
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_chat_fireworks_generate() -> None:
    """Test ChatFireworks wrapper with generate."""
    chat = ChatFireworks(model_kwargs={"n": 2})
    message = HumanMessage(content="Hello")
    response = chat.generate([[message], [message]])
    assert isinstance(response, LLMResult)
    assert len(response.generations) == 2
    for generations in response.generations:
        assert len(generations) == 2
        for generation in generations:
            assert isinstance(generation, ChatGeneration)
            assert isinstance(generation.text, str)
            assert generation.text == generation.message.content


def test_chat_fireworks_multiple_completions() -> None:
    """Test ChatFireworks wrapper with multiple completions."""
    chat = ChatFireworks(model_kwargs={"n": 5})
    message = HumanMessage(content="Hello")
    response = chat._generate([message])
    assert isinstance(response, ChatResult)
    assert len(response.generations) == 5
    for generation in response.generations:
        assert isinstance(generation.message, BaseMessage)
        assert isinstance(generation.message.content, str)


def test_chat_fireworks_llm_output_contains_model_id() -> None:
    """Test llm_output contains model_id."""
    chat = ChatFireworks()
    message = HumanMessage(content="Hello")
    llm_result = chat.generate([[message]])
    assert llm_result.llm_output is not None
    assert llm_result.llm_output["model"] == chat.model


def test_fireworks_streaming() -> None:
    """Test streaming tokens from Fireworks."""
    llm = ChatFireworks()

    for token in llm.stream("I'm Pickle Rick"):
        assert isinstance(token.content, str)


@pytest.mark.asyncio
async def test_chat_fireworks_agenerate() -> None:
    """Test ChatFireworks wrapper with generate."""
    chat = ChatFireworks(model_kwargs={"n": 2})
    message = HumanMessage(content="Hello")
    response = await chat.agenerate([[message], [message]])
    assert isinstance(response, LLMResult)
    assert len(response.generations) == 2
    for generations in response.generations:
        assert len(generations) == 2
        for generation in generations:
            assert isinstance(generation, ChatGeneration)
            assert isinstance(generation.text, str)
            assert generation.text == generation.message.content


@pytest.mark.asyncio
async def test_fireworks_astream() -> None:
    """Test streaming tokens from Fireworks."""
    llm = ChatFireworks()

    async for token in llm.astream("Who's the best quarterback in the NFL?"):
        assert isinstance(token.content, str)
libs/langchain/tests/integration_tests/llms/test_fireworks.py
@@ -1,31 +1,20 @@
"""Test Fireworks AI API Wrapper."""
from pathlib import Path

import pytest

from langchain.chains import RetrievalQA
from langchain.chains.llm import LLMChain
from langchain.document_loaders import TextLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import OpenAIChat
from langchain.llms.fireworks import Fireworks, FireworksChat
from langchain.llms.loading import load_llm
from typing import Generator
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.llms.fireworks import Fireworks
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.schema import LLMResult
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import DeepLake
import pytest


def test_fireworks_call() -> None:
    """Test valid call to fireworks."""
    llm = Fireworks(
        model_id="accounts/fireworks/models/fireworks-llama-v2-13b-chat", max_tokens=900
    )
    output = llm("What is the weather in NYC")
    llm = Fireworks()
    output = llm("Who's the best quarterback in the NFL?")
    assert isinstance(output, str)


@@ -44,36 +33,10 @@ def test_fireworks_in_chain() -> None:
    assert isinstance(output, str)


@pytest.mark.asyncio
async def test_openai_chat_async_generate() -> None:
    """Test async chat."""
    llm = OpenAIChat(max_tokens=10)
    output = await llm.agenerate(["Hello, how are you?"])
    assert isinstance(output, LLMResult)


def test_fireworks_model_param() -> None:
    """Tests model parameters for Fireworks"""
    llm = Fireworks(model="foo")
    assert llm.model_id == "foo"
    llm = Fireworks(model_id="foo")
    assert llm.model_id == "foo"


def test_fireworkschat_model_param() -> None:
    """Tests model parameters for FireworksChat"""
    llm = FireworksChat(model="foo")
    assert llm.model_id == "foo"
    llm = FireworksChat(model_id="foo")
    assert llm.model_id == "foo"


def test_saving_loading_llm(tmp_path: Path) -> None:
    """Test saving/loading an Fireworks LLM."""
    llm = Fireworks(max_tokens=10)
    llm.save(file_path=tmp_path / "fireworks.yaml")
    loaded_llm = load_llm(tmp_path / "fireworks.yaml")
    assert loaded_llm == llm
    assert llm.model == "foo"


def test_fireworks_multiple_prompts() -> None:
@@ -85,76 +48,39 @@ def test_fireworks_multiple_prompts() -> None:
    assert len(output.generations) == 2


def test_fireworks_chat() -> None:
    """Test FireworksChat."""
    llm = FireworksChat()
    output = llm("Name me 3 quick facts about the New England Patriots")
    assert isinstance(output, str)


async def test_fireworks_agenerate() -> None:
def test_fireworks_streaming() -> None:
    """Test stream completion."""
    llm = Fireworks()
    output = await llm.agenerate(["I'm a pickle", "I'm a pickle"])
    generator = llm.stream("Who's the best quarterback in the NFL?")
    assert isinstance(generator, Generator)

    for token in generator:
        assert isinstance(token, str)


@pytest.mark.asyncio
async def test_fireworks_streaming_async() -> None:
    """Test stream completion."""
    llm = Fireworks()

    async for token in llm.astream("Who's the best quarterback in the NFL?"):
        assert isinstance(token, str)


@pytest.mark.asyncio
async def test_fireworks_async_agenerate() -> None:
    """Test async."""
    llm = Fireworks()
    output = await llm.agenerate(["What is the best city to live in California?"])
    assert isinstance(output, LLMResult)


@pytest.mark.asyncio
async def test_fireworks_multiple_prompts_async_agenerate() -> None:
    llm = Fireworks()
    output = await llm.agenerate(
        ["How is the weather in New York today?", "I'm pickle rick"]
    )
    assert isinstance(output, LLMResult)
    assert isinstance(output.generations, list)
    assert len(output.generations) == 2


async def test_fireworkschat_agenerate() -> None:
    llm = FireworksChat(max_tokens=10)
    output = await llm.agenerate(["Hello, how are you?"])
    assert isinstance(output, LLMResult)
    assert isinstance(output.generations, list)
    assert len(output.generations) == 1


def test_fireworkschat_chain() -> None:
    embeddings = OpenAIEmbeddings()

    loader = TextLoader(
        "[workspace]/langchain-internal/docs/extras/modules/state_of_the_union.txt"
    )
    documents = loader.load()
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    docs = text_splitter.split_documents(documents)

    embeddings = OpenAIEmbeddings()

    db = DeepLake(
        dataset_path="./my_deeplake/", embedding_function=embeddings, overwrite=True
    )
    db.add_documents(docs)

    query = "What did the president say about Ketanji Brown Jackson"
    docs = db.similarity_search(query)

    qa = RetrievalQA.from_chain_type(
        llm=FireworksChat(),
        chain_type="stuff",
        retriever=db.as_retriever(),
    )
    query = "What did the president say about Ketanji Brown Jackson"
    output = qa.run(query)
    assert isinstance(output, str)


_EXPECTED_NUM_TOKENS = {
    "accounts/fireworks/models/fireworks-llama-v2-13b": 17,
    "accounts/fireworks/models/fireworks-llama-v2-7b": 17,
    "accounts/fireworks/models/fireworks-llama-v2-13b-chat": 17,
    "accounts/fireworks/models/fireworks-llama-v2-7b-chat": 17,
}

_MODELS = models = [
    "accounts/fireworks/models/fireworks-llama-v2-13b",
    "accounts/fireworks/models/fireworks-llama-v2-7b",
    "accounts/fireworks/models/fireworks-llama-v2-13b-chat",
    "accounts/fireworks/models/fireworks-llama-v2-7b-chat",
]


@pytest.mark.parametrize("model", _MODELS)
def test_fireworks_get_num_tokens(model: str) -> None:
    """Test get_tokens."""
    llm = Fireworks(model=model)
    assert llm.get_num_tokens("表情符号是\n🦜🔗") == _EXPECTED_NUM_TOKENS[model]