mirror of https://github.com/hwchase17/langchain.git (synced 2025-09-25 21:37:20 +00:00)
community[minor]: Integration for Friendli LLM and ChatFriendli ChatModel. (#17913)
## Description
- Add [Friendli](https://friendli.ai/) integration for the `Friendli` LLM and the `ChatFriendli` chat model.
- Unit tests and integration tests corresponding to this change are added.
- Documentation corresponding to this change is added.

## Dependencies
- The optional [`friendli-client`](https://pypi.org/project/friendli-client/) package is added; it is required only for those who use the `Friendli` or `ChatFriendli` models.

## Twitter handle
- https://twitter.com/friendliai
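For orientation, here is a minimal usage sketch of the two new entry points. It assumes `friendli-client` is installed and a valid `FRIENDLI_TOKEN` is exported; the model names are simply the defaults declared in this PR, not a requirement.

```python
# Minimal sketch: both classes read FRIENDLI_TOKEN from the environment
# when no friendli_token argument is given.
from langchain_community.chat_models import ChatFriendli
from langchain_community.llms import Friendli

# Text-completion interface added in llms/friendli.py.
llm = Friendli(model="mixtral-8x7b-instruct-v0-1", max_tokens=64)
print(llm.invoke("Say hello world."))  # returns a plain string

# Chat interface added in chat_models/friendli.py.
chat = ChatFriendli(model="llama-2-13b-chat", max_tokens=64)
print(chat.invoke("What is generative AI?").content)  # AIMessage content
```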
@@ -30,6 +30,7 @@ from langchain_community.chat_models.ernie import ErnieBotChat
from langchain_community.chat_models.everlyai import ChatEverlyAI
from langchain_community.chat_models.fake import FakeListChatModel
from langchain_community.chat_models.fireworks import ChatFireworks
from langchain_community.chat_models.friendli import ChatFriendli
from langchain_community.chat_models.gigachat import GigaChat
from langchain_community.chat_models.google_palm import ChatGooglePalm
from langchain_community.chat_models.gpt_router import GPTRouter
@@ -94,6 +95,7 @@ __all__ = [
    "ChatYandexGPT",
    "ChatBaichuan",
    "ChatHunyuan",
    "ChatFriendli",
    "GigaChat",
    "ChatSparkLLM",
    "VolcEngineMaasChat",
libs/community/langchain_community/chat_models/friendli.py (new file, 217 lines)

@@ -0,0 +1,217 @@
from __future__ import annotations

from typing import Any, AsyncIterator, Dict, Iterator, List, Optional

from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    agenerate_from_stream,
    generate_from_stream,
)
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult

from langchain_community.llms.friendli import BaseFriendli


def get_role(message: BaseMessage) -> str:
    """Get role of the message.

    Args:
        message (BaseMessage): The message object.

    Raises:
        ValueError: Raised when the message is of an unknown type.

    Returns:
        str: The role of the message.
    """
    if isinstance(message, ChatMessage) or isinstance(message, HumanMessage):
        return "user"
    if isinstance(message, AIMessage):
        return "assistant"
    if isinstance(message, SystemMessage):
        return "system"
    raise ValueError(f"Got unknown type {message}")


def get_chat_request(messages: List[BaseMessage]) -> Dict[str, Any]:
    """Get a request of the Friendli chat API.

    Args:
        messages (List[BaseMessage]): Messages comprising the conversation so far.

    Returns:
        Dict[str, Any]: The request for the Friendli chat API.
    """
    return {
        "messages": [
            {"role": get_role(message), "content": message.content}
            for message in messages
        ]
    }


class ChatFriendli(BaseChatModel, BaseFriendli):
    """Friendli LLM for chat.

    The ``friendli-client`` package should be installed with
    `pip install friendli-client`. You must set the ``FRIENDLI_TOKEN`` environment
    variable or provide the value of your personal access token for the
    ``friendli_token`` argument.

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatFriendli

            chat = ChatFriendli(
                model="llama-2-13b-chat", friendli_token="YOUR FRIENDLI TOKEN"
            )
            chat.invoke("What is generative AI?")
    """

    model: str = "llama-2-13b-chat"

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"friendli_token": "FRIENDLI_TOKEN"}

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling Friendli completions API."""
        return {
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
            "max_tokens": self.max_tokens,
            "stop": self.stop,
            "temperature": self.temperature,
            "top_p": self.top_p,
        }

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {"model": self.model, **self._default_params}

    @property
    def _llm_type(self) -> str:
        return "friendli-chat"

    def _get_invocation_params(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> Dict[str, Any]:
        """Get the parameters used to invoke the model."""
        params = self._default_params
        if self.stop is not None and stop is not None:
            raise ValueError("`stop` found in both the input and default params.")
        elif self.stop is not None:
            params["stop"] = self.stop
        else:
            params["stop"] = stop
        return {**params, **kwargs}

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        params = self._get_invocation_params(stop=stop, **kwargs)
        stream = self.client.chat.completions.create(
            **get_chat_request(messages), stream=True, model=self.model, **params
        )
        for chunk in stream:
            delta = chunk.choices[0].delta.content
            if delta:
                yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
                if run_manager:
                    run_manager.on_llm_new_token(delta)

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        params = self._get_invocation_params(stop=stop, **kwargs)
        stream = await self.async_client.chat.completions.create(
            **get_chat_request(messages), stream=True, model=self.model, **params
        )
        async for chunk in stream:
            delta = chunk.choices[0].delta.content
            if delta:
                yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
                if run_manager:
                    await run_manager.on_llm_new_token(delta)

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        params = self._get_invocation_params(stop=stop, **kwargs)
        response = self.client.chat.completions.create(
            messages=[
                {
                    "role": get_role(message),
                    "content": message.content,
                }
                for message in messages
            ],
            stream=False,
            model=self.model,
            **params,
        )

        message = AIMessage(content=response.choices[0].message.content)
        return ChatResult(generations=[ChatGeneration(message=message)])

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)

        params = self._get_invocation_params(stop=stop, **kwargs)
        response = await self.async_client.chat.completions.create(
            messages=[
                {
                    "role": get_role(message),
                    "content": message.content,
                }
                for message in messages
            ],
            stream=False,
            model=self.model,
            **params,
        )

        message = AIMessage(content=response.choices[0].message.content)
        return ChatResult(generations=[ChatGeneration(message=message)])
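As a usage note for the class above: streaming yields chunks backed by `AIMessageChunk`, whose `content` mirrors the delta text produced in `_stream`. A minimal sketch, assuming a valid `FRIENDLI_TOKEN` in the environment:

```python
from langchain_community.chat_models import ChatFriendli

chat = ChatFriendli(model="llama-2-13b-chat", max_tokens=32)

# Each chunk carries the streamed delta; callbacks fire via on_llm_new_token.
for chunk in chat.stream("Say hello world."):
    print(chunk.content, end="", flush=True)
```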
@@ -209,6 +209,12 @@ def _import_forefrontai() -> Type[BaseLLM]:
    return ForefrontAI


def _import_friendli() -> Type[BaseLLM]:
    from langchain_community.llms.friendli import Friendli

    return Friendli


def _import_gigachat() -> Type[BaseLLM]:
    from langchain_community.llms.gigachat import GigaChat

@@ -665,6 +671,8 @@ def __getattr__(name: str) -> Any:
        return _import_fireworks()
    elif name == "ForefrontAI":
        return _import_forefrontai()
    elif name == "Friendli":
        return _import_friendli()
    elif name == "GigaChat":
        return _import_gigachat()
    elif name == "GooglePalm":
@@ -827,6 +835,7 @@ __all__ = [
    "FakeListLLM",
    "Fireworks",
    "ForefrontAI",
    "Friendli",
    "GigaChat",
    "GPT4All",
    "GooglePalm",
@@ -919,6 +928,7 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
        "edenai": _import_edenai,
        "fake-list": _import_fake,
        "forefrontai": _import_forefrontai,
        "friendli": _import_friendli,
        "giga-chat-model": _import_gigachat,
        "google_palm": _import_google_palm,
        "gooseai": _import_gooseai,
libs/community/langchain_community/llms/friendli.py (new file, 350 lines)

@@ -0,0 +1,350 @@
from __future__ import annotations

import os
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional

from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import LLM
from langchain_core.load.serializable import Serializable
from langchain_core.outputs import GenerationChunk, LLMResult
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils.env import get_from_dict_or_env
from langchain_core.utils.utils import convert_to_secret_str


def _stream_response_to_generation_chunk(stream_response: Any) -> GenerationChunk:
    """Convert a stream response to a generation chunk."""
    if stream_response.event == "token_sampled":
        return GenerationChunk(
            text=stream_response.text,
            generation_info={"token": str(stream_response.token)},
        )
    return GenerationChunk(text="")


class BaseFriendli(Serializable):
    """Base class of Friendli."""

    # Friendli client.
    client: Any = Field(default=None, exclude=True)
    # Friendli async client.
    async_client: Any = Field(default=None, exclude=True)
    # Model name to use.
    model: str = "mixtral-8x7b-instruct-v0-1"
    # Friendli personal access token to run as.
    friendli_token: Optional[SecretStr] = None
    # Friendli team ID to run as.
    friendli_team: Optional[str] = None
    # Whether to enable streaming mode.
    streaming: bool = False
    # Number between -2.0 and 2.0. Positive values penalize tokens that have been
    # sampled, taking into account their frequency in the preceding text. This
    # penalization diminishes the model's tendency to reproduce identical lines
    # verbatim.
    frequency_penalty: Optional[float] = None
    # Number between -2.0 and 2.0. Positive values penalize tokens that have been
    # sampled at least once in the existing text.
    presence_penalty: Optional[float] = None
    # The maximum number of tokens to generate. The length of your input tokens plus
    # `max_tokens` should not exceed the model's maximum length (e.g., 2048 for OpenAI
    # GPT-3).
    max_tokens: Optional[int] = None
    # When one of the stop phrases appears in the generation result, the API will stop
    # generation. The phrase is included in the generated result. If you are using
    # beam search, all of the active beams should contain the stop phrase to terminate
    # generation. Before checking whether a stop phrase is included in the result, the
    # phrase is converted into tokens.
    stop: Optional[List[str]] = None
    # Sampling temperature. A smaller temperature makes the generation result closer
    # to greedy, argmax (i.e., `top_k = 1`) sampling. If it is `None`, then 1.0 is
    # used.
    temperature: Optional[float] = None
    # Tokens comprising the top `top_p` probability mass are kept for sampling. Numbers
    # between 0.0 (exclusive) and 1.0 (inclusive) are allowed. If it is `None`, then
    # 1.0 is used by default.
    top_p: Optional[float] = None

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate if personal access token is provided in environment."""
        try:
            import friendli
        except ImportError as e:
            raise ImportError(
                "Could not import friendli-client python package. "
                "Please install it with `pip install friendli-client`."
            ) from e

        friendli_token = convert_to_secret_str(
            get_from_dict_or_env(values, "friendli_token", "FRIENDLI_TOKEN")
        )
        values["friendli_token"] = friendli_token
        friendli_token_str = friendli_token.get_secret_value()
        friendli_team = values["friendli_team"] or os.getenv("FRIENDLI_TEAM")
        values["friendli_team"] = friendli_team
        values["client"] = values["client"] or friendli.Friendli(
            token=friendli_token_str, team_id=friendli_team
        )
        values["async_client"] = values["async_client"] or friendli.AsyncFriendli(
            token=friendli_token_str, team_id=friendli_team
        )
        return values


class Friendli(LLM, BaseFriendli):
    """Friendli LLM.

    The ``friendli-client`` package should be installed with
    `pip install friendli-client`. You must set the ``FRIENDLI_TOKEN`` environment
    variable or provide the value of your personal access token for the
    ``friendli_token`` argument.

    Example:
        .. code-block:: python

            from langchain_community.llms import Friendli

            friendli = Friendli(
                model="mixtral-8x7b-instruct-v0-1", friendli_token="YOUR FRIENDLI TOKEN"
            )
    """

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"friendli_token": "FRIENDLI_TOKEN"}

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling Friendli completions API."""
        return {
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
            "max_tokens": self.max_tokens,
            "stop": self.stop,
            "temperature": self.temperature,
            "top_p": self.top_p,
        }

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {"model": self.model, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "friendli"

    def _get_invocation_params(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> Dict[str, Any]:
        """Get the parameters used to invoke the model."""
        params = self._default_params
        if self.stop is not None and stop is not None:
            raise ValueError("`stop` found in both the input and default params.")
        elif self.stop is not None:
            params["stop"] = self.stop
        else:
            params["stop"] = stop
        return {**params, **kwargs}

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out Friendli's completions API.

        Args:
            prompt (str): The text prompt to generate completion for.
            stop (Optional[List[str]], optional): When one of the stop phrases appears
                in the generation result, the API will stop generation. The stop
                phrases are excluded from the result. If beam search is enabled, all
                of the active beams should contain the stop phrase to terminate
                generation. Before checking whether a stop phrase is included in the
                result, the phrase is converted into tokens. We recommend using
                stop_tokens because it is clearer. For example, after tokenization,
                the phrases "clear" and " clear" can result in different token
                sequences due to the prepended space character. Defaults to None.

        Returns:
            str: The generated text output.

        Example:
            .. code-block:: python

                response = friendli("Give me a recipe for the Old Fashioned cocktail.")
        """
        params = self._get_invocation_params(stop=stop, **kwargs)
        completion = self.client.completions.create(
            model=self.model, prompt=prompt, stream=False, **params
        )
        return completion.choices[0].text

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out Friendli's completions API asynchronously.

        Args:
            prompt (str): The text prompt to generate completion for.
            stop (Optional[List[str]], optional): When one of the stop phrases appears
                in the generation result, the API will stop generation. The stop
                phrases are excluded from the result. If beam search is enabled, all
                of the active beams should contain the stop phrase to terminate
                generation. Before checking whether a stop phrase is included in the
                result, the phrase is converted into tokens. We recommend using
                stop_tokens because it is clearer. For example, after tokenization,
                the phrases "clear" and " clear" can result in different token
                sequences due to the prepended space character. Defaults to None.

        Returns:
            str: The generated text output.

        Example:
            .. code-block:: python

                response = await friendli("Tell me a joke.")
        """
        params = self._get_invocation_params(stop=stop, **kwargs)
        completion = await self.async_client.completions.create(
            model=self.model, prompt=prompt, stream=False, **params
        )
        return completion.choices[0].text

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        params = self._get_invocation_params(stop=stop, **kwargs)
        stream = self.client.completions.create(
            model=self.model, prompt=prompt, stream=True, **params
        )
        for line in stream:
            chunk = _stream_response_to_generation_chunk(line)
            yield chunk
            if run_manager:
                run_manager.on_llm_new_token(line.text, chunk=chunk)

    async def _astream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[GenerationChunk]:
        params = self._get_invocation_params(stop=stop, **kwargs)
        stream = await self.async_client.completions.create(
            model=self.model, prompt=prompt, stream=True, **params
        )
        async for line in stream:
            chunk = _stream_response_to_generation_chunk(line)
            yield chunk
            if run_manager:
                await run_manager.on_llm_new_token(line.text, chunk=chunk)

    def _generate(
        self,
        prompts: list[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Call out Friendli's completions API with k unique prompts.

        Args:
            prompts (list[str]): The text prompts to generate completions for.
            stop (Optional[List[str]], optional): When one of the stop phrases appears
                in the generation result, the API will stop generation. The stop
                phrases are excluded from the result. If beam search is enabled, all
                of the active beams should contain the stop phrase to terminate
                generation. Before checking whether a stop phrase is included in the
                result, the phrase is converted into tokens. We recommend using
                stop_tokens because it is clearer. For example, after tokenization,
                the phrases "clear" and " clear" can result in different token
                sequences due to the prepended space character. Defaults to None.

        Returns:
            LLMResult: The generated outputs.

        Example:
            .. code-block:: python

                response = friendli.generate(["Tell me a joke."])
        """
        llm_output = {"model": self.model}
        if self.streaming:
            if len(prompts) > 1:
                raise ValueError("Cannot stream results with multiple prompts.")

            generation: Optional[GenerationChunk] = None
            for chunk in self._stream(prompts[0], stop, run_manager, **kwargs):
                if generation is None:
                    generation = chunk
                else:
                    generation += chunk
            assert generation is not None
            return LLMResult(generations=[[generation]], llm_output=llm_output)

        llm_result = super()._generate(prompts, stop, run_manager, **kwargs)
        llm_result.llm_output = llm_output
        return llm_result

    async def _agenerate(
        self,
        prompts: list[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Call out Friendli's completions API asynchronously with k unique prompts.

        Args:
            prompts (list[str]): The text prompts to generate completions for.
            stop (Optional[List[str]], optional): When one of the stop phrases appears
                in the generation result, the API will stop generation. The stop
                phrases are excluded from the result. If beam search is enabled, all
                of the active beams should contain the stop phrase to terminate
                generation. Before checking whether a stop phrase is included in the
                result, the phrase is converted into tokens. We recommend using
                stop_tokens because it is clearer. For example, after tokenization,
                the phrases "clear" and " clear" can result in different token
                sequences due to the prepended space character. Defaults to None.

        Returns:
            LLMResult: The generated outputs.

        Example:
            .. code-block:: python

                response = await friendli.agenerate(
                    ["Give me a recipe for the Old Fashioned cocktail."]
                )
        """
        llm_output = {"model": self.model}
        if self.streaming:
            if len(prompts) > 1:
                raise ValueError("Cannot stream results with multiple prompts.")

            generation = None
            async for chunk in self._astream(prompts[0], stop, run_manager, **kwargs):
                if generation is None:
                    generation = chunk
                else:
                    generation += chunk
            assert generation is not None
            return LLMResult(generations=[[generation]], llm_output=llm_output)

        llm_result = await super()._agenerate(prompts, stop, run_manager, **kwargs)
        llm_result.llm_output = llm_output
        return llm_result
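As a usage note for `Friendli` above: `_generate` attaches the model name to `llm_output`, so batch results can be traced back to the model that produced them. A minimal sketch, assuming `FRIENDLI_TOKEN` is set (the placeholder token below is illustrative only):

```python
import os

from langchain_community.llms import Friendli

# validate_environment falls back to FRIENDLI_TOKEN when no token argument is passed.
os.environ.setdefault("FRIENDLI_TOKEN", "YOUR FRIENDLI TOKEN")  # placeholder

llm = Friendli(model="mixtral-8x7b-instruct-v0-1", max_tokens=16)
result = llm.generate(["Say hello world.", "Say bye world."])

print(result.llm_output)              # {'model': 'mixtral-8x7b-instruct-v0-1'}
print(result.generations[0][0].text)  # completion text for the first prompt
```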
libs/community/poetry.lock (generated, 3572 lines)

File diff suppressed because one or more lines are too long
@@ -96,6 +96,7 @@ oci = {version = "^2.119.1", optional = true}
rdflib = {version = "7.0.0", optional = true}
nvidia-riva-client = {version = "^2.14.0", optional = true}
tidb-vector = {version = ">=0.0.3,<1.0.0", optional = true}
friendli-client = {version = "^1.2.4", optional = true}

[tool.poetry.group.test]
optional = true

@@ -266,6 +267,7 @@ extended_testing = [
    "rdflib",
    "tidb-vector",
    "cloudpickle",
    "friendli-client"
]

[tool.ruff]
@@ -0,0 +1,105 @@
"""Test Friendli chat API."""

import pytest
from langchain_core.messages.ai import AIMessage
from langchain_core.messages.human import HumanMessage
from langchain_core.outputs.generation import Generation
from langchain_core.outputs.llm_result import LLMResult

from langchain_community.chat_models.friendli import ChatFriendli


@pytest.fixture
def friendli_chat() -> ChatFriendli:
    """Friendli LLM for chat."""
    return ChatFriendli(temperature=0, max_tokens=10)


def test_friendli_call(friendli_chat: ChatFriendli) -> None:
    """Test call."""
    message = HumanMessage(content="What is generative AI?")
    output = friendli_chat([message])
    assert isinstance(output, AIMessage)
    assert isinstance(output.content, str)


def test_friendli_invoke(friendli_chat: ChatFriendli) -> None:
    """Test invoke."""
    output = friendli_chat.invoke("What is generative AI?")
    assert isinstance(output, AIMessage)
    assert isinstance(output.content, str)


async def test_friendli_ainvoke(friendli_chat: ChatFriendli) -> None:
    """Test async invoke."""
    output = await friendli_chat.ainvoke("What is generative AI?")
    assert isinstance(output, AIMessage)
    assert isinstance(output.content, str)


def test_friendli_batch(friendli_chat: ChatFriendli) -> None:
    """Test batch."""
    outputs = friendli_chat.batch(["What is generative AI?", "What is generative AI?"])
    for output in outputs:
        assert isinstance(output, AIMessage)
        assert isinstance(output.content, str)


async def test_friendli_abatch(friendli_chat: ChatFriendli) -> None:
    """Test async batch."""
    outputs = await friendli_chat.abatch(
        ["What is generative AI?", "What is generative AI?"]
    )
    for output in outputs:
        assert isinstance(output, AIMessage)
        assert isinstance(output.content, str)


def test_friendli_generate(friendli_chat: ChatFriendli) -> None:
    """Test generate."""
    message = HumanMessage(content="What is generative AI?")
    result = friendli_chat.generate([[message], [message]])
    assert isinstance(result, LLMResult)
    generations = result.generations
    assert len(generations) == 2
    for generation in generations:
        gen_ = generation[0]
        assert isinstance(gen_, Generation)
        text = gen_.text
        assert isinstance(text, str)
        generation_info = gen_.generation_info
        if generation_info is not None:
            assert "token" in generation_info


async def test_friendli_agenerate(friendli_chat: ChatFriendli) -> None:
    """Test async generate."""
    message = HumanMessage(content="What is generative AI?")
    result = await friendli_chat.agenerate([[message], [message]])
    assert isinstance(result, LLMResult)
    generations = result.generations
    assert len(generations) == 2
    for generation in generations:
        gen_ = generation[0]
        assert isinstance(gen_, Generation)
        text = gen_.text
        assert isinstance(text, str)
        generation_info = gen_.generation_info
        if generation_info is not None:
            assert "token" in generation_info


def test_friendli_stream(friendli_chat: ChatFriendli) -> None:
    """Test stream."""
    stream = friendli_chat.stream("Say hello world.")
    for chunk in stream:
        assert isinstance(chunk, AIMessage)
        assert isinstance(chunk.content, str)


async def test_friendli_astream(friendli_chat: ChatFriendli) -> None:
    """Test async stream."""
    stream = friendli_chat.astream("Say hello world.")
    async for chunk in stream:
        assert isinstance(chunk, AIMessage)
        assert isinstance(chunk.content, str)
libs/community/tests/integration_tests/llms/test_friendli.py (new file, 91 lines)

@@ -0,0 +1,91 @@
"""Test Friendli API."""

import pytest
from langchain_core.outputs.generation import Generation
from langchain_core.outputs.llm_result import LLMResult

from langchain_community.llms.friendli import Friendli


@pytest.fixture
def friendli_llm() -> Friendli:
    """Friendli LLM."""
    return Friendli(temperature=0, max_tokens=10)


def test_friendli_call(friendli_llm: Friendli) -> None:
    """Test call."""
    output = friendli_llm("Say hello world.")
    assert isinstance(output, str)


def test_friendli_invoke(friendli_llm: Friendli) -> None:
    """Test invoke."""
    output = friendli_llm.invoke("Say hello world.")
    assert isinstance(output, str)


async def test_friendli_ainvoke(friendli_llm: Friendli) -> None:
    """Test async invoke."""
    output = await friendli_llm.ainvoke("Say hello world.")
    assert isinstance(output, str)


def test_friendli_batch(friendli_llm: Friendli) -> None:
    """Test batch."""
    outputs = friendli_llm.batch(["Say hello world.", "Say bye world."])
    for output in outputs:
        assert isinstance(output, str)


async def test_friendli_abatch(friendli_llm: Friendli) -> None:
    """Test async batch."""
    outputs = await friendli_llm.abatch(["Say hello world.", "Say bye world."])
    for output in outputs:
        assert isinstance(output, str)


def test_friendli_generate(friendli_llm: Friendli) -> None:
    """Test generate."""
    result = friendli_llm.generate(["Say hello world.", "Say bye world."])
    assert isinstance(result, LLMResult)
    generations = result.generations
    assert len(generations) == 2
    for generation in generations:
        gen_ = generation[0]
        assert isinstance(gen_, Generation)
        text = gen_.text
        assert isinstance(text, str)
        generation_info = gen_.generation_info
        if generation_info is not None:
            assert "token" in generation_info


async def test_friendli_agenerate(friendli_llm: Friendli) -> None:
    """Test async generate."""
    result = await friendli_llm.agenerate(["Say hello world.", "Say bye world."])
    assert isinstance(result, LLMResult)
    generations = result.generations
    assert len(generations) == 2
    for generation in generations:
        gen_ = generation[0]
        assert isinstance(gen_, Generation)
        text = gen_.text
        assert isinstance(text, str)
        generation_info = gen_.generation_info
        if generation_info is not None:
            assert "token" in generation_info


def test_friendli_stream(friendli_llm: Friendli) -> None:
    """Test stream."""
    stream = friendli_llm.stream("Say hello world.")
    for chunk in stream:
        assert isinstance(chunk, str)


async def test_friendli_astream(friendli_llm: Friendli) -> None:
    """Test async stream."""
    stream = friendli_llm.astream("Say hello world.")
    async for chunk in stream:
        assert isinstance(chunk, str)
libs/community/tests/unit_tests/chat_models/test_friendli.py (new file, 197 lines)

@@ -0,0 +1,197 @@
"""Test Friendli LLM for chat."""

from unittest.mock import AsyncMock, MagicMock, Mock

import pytest
from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture, MonkeyPatch

from langchain_community.adapters.openai import aenumerate
from langchain_community.chat_models import ChatFriendli


@pytest.fixture
def mock_friendli_client() -> Mock:
    """Mock instance of Friendli client."""
    return Mock()


@pytest.fixture
def mock_friendli_async_client() -> AsyncMock:
    """Mock instance of Friendli async client."""
    return AsyncMock()


@pytest.fixture
def chat_friendli(
    mock_friendli_client: Mock, mock_friendli_async_client: AsyncMock
) -> ChatFriendli:
    """Friendli LLM for chat with mock clients."""
    return ChatFriendli(
        friendli_token=SecretStr("personal-access-token"),
        client=mock_friendli_client,
        async_client=mock_friendli_async_client,
    )


@pytest.mark.requires("friendli")
def test_friendli_token_is_secret_string(capsys: CaptureFixture) -> None:
    """Test if friendli token is stored as a SecretStr."""
    fake_token_value = "personal-access-token"
    chat = ChatFriendli(friendli_token=fake_token_value)
    assert isinstance(chat.friendli_token, SecretStr)
    assert chat.friendli_token.get_secret_value() == fake_token_value
    print(chat.friendli_token, end="")  # noqa: T201
    captured = capsys.readouterr()
    assert captured.out == "**********"


@pytest.mark.requires("friendli")
def test_friendli_token_read_from_env(
    monkeypatch: MonkeyPatch, capsys: CaptureFixture
) -> None:
    """Test if friendli token can be parsed from environment."""
    fake_token_value = "personal-access-token"
    monkeypatch.setenv("FRIENDLI_TOKEN", fake_token_value)
    chat = ChatFriendli()
    assert isinstance(chat.friendli_token, SecretStr)
    assert chat.friendli_token.get_secret_value() == fake_token_value
    print(chat.friendli_token, end="")  # noqa: T201
    captured = capsys.readouterr()
    assert captured.out == "**********"


@pytest.mark.requires("friendli")
def test_friendli_invoke(
    mock_friendli_client: Mock, chat_friendli: ChatFriendli
) -> None:
    """Test invocation with friendli."""
    mock_message = Mock()
    mock_message.content = "Hello Friendli"
    mock_message.role = "assistant"
    mock_choice = Mock()
    mock_choice.message = mock_message
    mock_response = Mock()
    mock_response.choices = [mock_choice]
    mock_friendli_client.chat.completions.create.return_value = mock_response

    result = chat_friendli.invoke("Hello langchain")
    assert result.content == "Hello Friendli"
    mock_friendli_client.chat.completions.create.assert_called_once_with(
        messages=[{"role": "user", "content": "Hello langchain"}],
        stream=False,
        model=chat_friendli.model,
        frequency_penalty=None,
        presence_penalty=None,
        max_tokens=None,
        stop=None,
        temperature=None,
        top_p=None,
    )


@pytest.mark.requires("friendli")
async def test_friendli_ainvoke(
    mock_friendli_async_client: AsyncMock, chat_friendli: ChatFriendli
) -> None:
    """Test async invocation with friendli."""
    mock_message = Mock()
    mock_message.content = "Hello Friendli"
    mock_message.role = "assistant"
    mock_choice = Mock()
    mock_choice.message = mock_message
    mock_response = Mock()
    mock_response.choices = [mock_choice]
    mock_friendli_async_client.chat.completions.create.return_value = mock_response

    result = await chat_friendli.ainvoke("Hello langchain")
    assert result.content == "Hello Friendli"
    mock_friendli_async_client.chat.completions.create.assert_awaited_once_with(
        messages=[{"role": "user", "content": "Hello langchain"}],
        stream=False,
        model=chat_friendli.model,
        frequency_penalty=None,
        presence_penalty=None,
        max_tokens=None,
        stop=None,
        temperature=None,
        top_p=None,
    )


@pytest.mark.requires("friendli")
def test_friendli_stream(
    mock_friendli_client: Mock, chat_friendli: ChatFriendli
) -> None:
    """Test stream with friendli."""
    mock_delta_0 = Mock()
    mock_delta_0.content = "Hello "
    mock_delta_1 = Mock()
    mock_delta_1.content = "Friendli"
    mock_choice_0 = Mock()
    mock_choice_0.delta = mock_delta_0
    mock_choice_1 = Mock()
    mock_choice_1.delta = mock_delta_1
    mock_chunk_0 = Mock()
    mock_chunk_0.choices = [mock_choice_0]
    mock_chunk_1 = Mock()
    mock_chunk_1.choices = [mock_choice_1]
    mock_stream = MagicMock()
    mock_chunks = [mock_chunk_0, mock_chunk_1]
    mock_stream.__iter__.return_value = mock_chunks

    mock_friendli_client.chat.completions.create.return_value = mock_stream
    stream = chat_friendli.stream("Hello langchain")
    for i, chunk in enumerate(stream):
        assert chunk.content == mock_chunks[i].choices[0].delta.content

    mock_friendli_client.chat.completions.create.assert_called_once_with(
        messages=[{"role": "user", "content": "Hello langchain"}],
        stream=True,
        model=chat_friendli.model,
        frequency_penalty=None,
        presence_penalty=None,
        max_tokens=None,
        stop=None,
        temperature=None,
        top_p=None,
    )


@pytest.mark.requires("friendli")
async def test_friendli_astream(
    mock_friendli_async_client: AsyncMock, chat_friendli: ChatFriendli
) -> None:
    """Test async stream with friendli."""
    mock_delta_0 = Mock()
    mock_delta_0.content = "Hello "
    mock_delta_1 = Mock()
    mock_delta_1.content = "Friendli"
    mock_choice_0 = Mock()
    mock_choice_0.delta = mock_delta_0
    mock_choice_1 = Mock()
    mock_choice_1.delta = mock_delta_1
    mock_chunk_0 = Mock()
    mock_chunk_0.choices = [mock_choice_0]
    mock_chunk_1 = Mock()
    mock_chunk_1.choices = [mock_choice_1]
    mock_stream = AsyncMock()
    mock_chunks = [mock_chunk_0, mock_chunk_1]
    mock_stream.__aiter__.return_value = mock_chunks

    mock_friendli_async_client.chat.completions.create.return_value = mock_stream
    stream = chat_friendli.astream("Hello langchain")
    async for i, chunk in aenumerate(stream):
        assert chunk.content == mock_chunks[i].choices[0].delta.content

    mock_friendli_async_client.chat.completions.create.assert_awaited_once_with(
        messages=[{"role": "user", "content": "Hello langchain"}],
        stream=True,
        model=chat_friendli.model,
        frequency_penalty=None,
        presence_penalty=None,
        max_tokens=None,
        stop=None,
        temperature=None,
        top_p=None,
    )
@@ -43,6 +43,7 @@ EXPECTED_ALL = [
    "ChatZhipuAI",
    "ChatPerplexity",
    "ChatKinetica",
    "ChatFriendli",
]
libs/community/tests/unit_tests/llms/test_friendli.py (new file, 179 lines)

@@ -0,0 +1,179 @@
"""Test Friendli LLM."""

from unittest.mock import AsyncMock, MagicMock, Mock

import pytest
from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture, MonkeyPatch

from langchain_community.adapters.openai import aenumerate
from langchain_community.llms.friendli import Friendli


@pytest.fixture
def mock_friendli_client() -> Mock:
    """Mock instance of Friendli client."""
    return Mock()


@pytest.fixture
def mock_friendli_async_client() -> AsyncMock:
    """Mock instance of Friendli async client."""
    return AsyncMock()


@pytest.fixture
def friendli_llm(
    mock_friendli_client: Mock, mock_friendli_async_client: AsyncMock
) -> Friendli:
    """Friendli LLM with mock clients."""
    return Friendli(
        friendli_token=SecretStr("personal-access-token"),
        client=mock_friendli_client,
        async_client=mock_friendli_async_client,
    )


@pytest.mark.requires("friendli")
def test_friendli_token_is_secret_string(capsys: CaptureFixture) -> None:
    """Test if friendli token is stored as a SecretStr."""
    fake_token_value = "personal-access-token"
    chat = Friendli(friendli_token=fake_token_value)
    assert isinstance(chat.friendli_token, SecretStr)
    assert chat.friendli_token.get_secret_value() == fake_token_value
    print(chat.friendli_token, end="")  # noqa: T201
    captured = capsys.readouterr()
    assert captured.out == "**********"


@pytest.mark.requires("friendli")
def test_friendli_token_read_from_env(
    monkeypatch: MonkeyPatch, capsys: CaptureFixture
) -> None:
    """Test if friendli token can be parsed from environment."""
    fake_token_value = "personal-access-token"
    monkeypatch.setenv("FRIENDLI_TOKEN", fake_token_value)
    chat = Friendli()
    assert isinstance(chat.friendli_token, SecretStr)
    assert chat.friendli_token.get_secret_value() == fake_token_value
    print(chat.friendli_token, end="")  # noqa: T201
    captured = capsys.readouterr()
    assert captured.out == "**********"


@pytest.mark.requires("friendli")
def test_friendli_invoke(mock_friendli_client: Mock, friendli_llm: Friendli) -> None:
    """Test invocation with friendli."""
    mock_choice = Mock()
    mock_choice.text = "Hello Friendli"
    mock_response = Mock()
    mock_response.choices = [mock_choice]
    mock_friendli_client.completions.create.return_value = mock_response

    result = friendli_llm.invoke("Hello langchain")
    assert result == "Hello Friendli"
    mock_friendli_client.completions.create.assert_called_once_with(
        model=friendli_llm.model,
        prompt="Hello langchain",
        stream=False,
        frequency_penalty=None,
        presence_penalty=None,
        max_tokens=None,
        stop=None,
        temperature=None,
        top_p=None,
    )


@pytest.mark.requires("friendli")
async def test_friendli_ainvoke(
    mock_friendli_async_client: AsyncMock, friendli_llm: Friendli
) -> None:
    """Test async invocation with friendli."""
    mock_choice = Mock()
    mock_choice.text = "Hello Friendli"
    mock_response = Mock()
    mock_response.choices = [mock_choice]
    mock_friendli_async_client.completions.create.return_value = mock_response

    result = await friendli_llm.ainvoke("Hello langchain")
    assert result == "Hello Friendli"
    mock_friendli_async_client.completions.create.assert_awaited_once_with(
        model=friendli_llm.model,
        prompt="Hello langchain",
        stream=False,
        frequency_penalty=None,
        presence_penalty=None,
        max_tokens=None,
        stop=None,
        temperature=None,
        top_p=None,
    )


@pytest.mark.requires("friendli")
def test_friendli_stream(mock_friendli_client: Mock, friendli_llm: Friendli) -> None:
    """Test stream with friendli."""
    mock_chunk_0 = Mock()
    mock_chunk_0.event = "token_sampled"
    mock_chunk_0.text = "Hello "
    mock_chunk_0.token = 0
    mock_chunk_1 = Mock()
    mock_chunk_1.event = "token_sampled"
    mock_chunk_1.text = "Friendli"
    mock_chunk_1.token = 1
    mock_stream = MagicMock()
    mock_chunks = [mock_chunk_0, mock_chunk_1]
    mock_stream.__iter__.return_value = mock_chunks

    mock_friendli_client.completions.create.return_value = mock_stream
    stream = friendli_llm.stream("Hello langchain")
    for i, chunk in enumerate(stream):
        assert chunk == mock_chunks[i].text

    mock_friendli_client.completions.create.assert_called_once_with(
        model=friendli_llm.model,
        prompt="Hello langchain",
        stream=True,
        frequency_penalty=None,
        presence_penalty=None,
        max_tokens=None,
        stop=None,
        temperature=None,
        top_p=None,
    )


@pytest.mark.requires("friendli")
async def test_friendli_astream(
    mock_friendli_async_client: AsyncMock, friendli_llm: Friendli
) -> None:
    """Test async stream with friendli."""
    mock_chunk_0 = Mock()
    mock_chunk_0.event = "token_sampled"
    mock_chunk_0.text = "Hello "
    mock_chunk_0.token = 0
    mock_chunk_1 = Mock()
    mock_chunk_1.event = "token_sampled"
    mock_chunk_1.text = "Friendli"
    mock_chunk_1.token = 1
    mock_stream = AsyncMock()
    mock_chunks = [mock_chunk_0, mock_chunk_1]
    mock_stream.__aiter__.return_value = mock_chunks

    mock_friendli_async_client.completions.create.return_value = mock_stream
    stream = friendli_llm.astream("Hello langchain")
    async for i, chunk in aenumerate(stream):
        assert chunk == mock_chunks[i].text

    mock_friendli_async_client.completions.create.assert_awaited_once_with(
        model=friendli_llm.model,
        prompt="Hello langchain",
        stream=True,
        frequency_penalty=None,
        presence_penalty=None,
        max_tokens=None,
        stop=None,
        temperature=None,
        top_p=None,
    )
@@ -30,6 +30,7 @@ EXPECT_ALL = [
    "FakeListLLM",
    "Fireworks",
    "ForefrontAI",
    "Friendli",
    "GigaChat",
    "GPT4All",
    "GooglePalm",