Description:
* Add _generate and _agenerate to support Fireworks batching.
* Add stop-word test cases.
* Allow opting out of the retry mechanism via a use_retry flag.

Issue: Not applicable
Dependencies: None
Tag maintainer: @baskaryan
This commit is contained in:
parent 3fbb2f3e52
commit 6ce276e099
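For context, here is a rough usage sketch of what the new batching and retry opt-out look like from the caller's side. It is illustrative only and not part of the diff: the prompts are made up, and it assumes a valid FIREWORKS_API_KEY is available in the environment.

# Illustrative sketch only -- not part of this diff. Assumes FIREWORKS_API_KEY
# is set in the environment; prompts and field values below are made up.
from langchain.llms.fireworks import Fireworks

llm = Fireworks(
    batch_size=20,     # new field: prompts per Fireworks request batch
    use_retry=False,   # new field: skip the tenacity retry decorator entirely
)

# _generate/_agenerate now accept a list of prompts, fan them out in
# batch_size chunks, and return one LLMResult with a generation per prompt.
result = llm.generate(
    ["What is 1 + 1?", "Name a prime number.", "What color is the sky?"],
    stop=["\n"],
)
for generations in result.generations:
    print(generations[0].text)

The diff below makes this work, first in the ChatFireworks chat model, then in the Fireworks completion LLM, and finally in the integration tests.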
@@ -89,6 +89,7 @@ class ChatFireworks(BaseChatModel):
     )
     fireworks_api_key: Optional[str] = None
     max_retries: int = 20
+    use_retry: bool = True
 
     @property
     def lc_secrets(self) -> Dict[str, str]:
@@ -134,7 +135,11 @@ class ChatFireworks(BaseChatModel):
             **self.model_kwargs,
         }
         response = completion_with_retry(
-            self, run_manager=run_manager, stop=stop, **params
+            self,
+            self.use_retry,
+            run_manager=run_manager,
+            stop=stop,
+            **params,
         )
         return self._create_chat_result(response)
 
@@ -152,7 +157,7 @@ class ChatFireworks(BaseChatModel):
             **self.model_kwargs,
         }
         response = await acompletion_with_retry(
-            self, run_manager=run_manager, stop=stop, **params
+            self, self.use_retry, run_manager=run_manager, stop=stop, **params
         )
         return self._create_chat_result(response)
 
@@ -195,7 +200,7 @@ class ChatFireworks(BaseChatModel):
             **self.model_kwargs,
         }
         for chunk in completion_with_retry(
-            self, run_manager=run_manager, stop=stop, **params
+            self, self.use_retry, run_manager=run_manager, stop=stop, **params
         ):
             choice = chunk.choices[0]
             chunk = _convert_delta_to_message_chunk(choice.delta, default_chunk_class)
@@ -224,7 +229,7 @@ class ChatFireworks(BaseChatModel):
             **self.model_kwargs,
         }
         async for chunk in await acompletion_with_retry_streaming(
-            self, run_manager=run_manager, stop=stop, **params
+            self, self.use_retry, run_manager=run_manager, stop=stop, **params
         ):
             choice = chunk.choices[0]
             chunk = _convert_delta_to_message_chunk(choice.delta, default_chunk_class)
@@ -238,8 +243,20 @@ class ChatFireworks(BaseChatModel):
             await run_manager.on_llm_new_token(token=chunk.content, chunk=chunk)
 
 
+def conditional_decorator(
+    condition: bool, decorator: Callable[[Any], Any]
+) -> Callable[[Any], Any]:
+    def actual_decorator(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
+        if condition:
+            return decorator(func)
+        return func
+
+    return actual_decorator
+
+
 def completion_with_retry(
     llm: ChatFireworks,
+    use_retry: bool,
     *,
     run_manager: Optional[CallbackManagerForLLMRun] = None,
     **kwargs: Any,
@@ -249,7 +266,7 @@ def completion_with_retry(
 
     retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
 
-    @retry_decorator
+    @conditional_decorator(use_retry, retry_decorator)
     def _completion_with_retry(**kwargs: Any) -> Any:
         return fireworks.client.ChatCompletion.create(
             **kwargs,
@@ -260,6 +277,7 @@ def completion_with_retry(
 
 async def acompletion_with_retry(
     llm: ChatFireworks,
+    use_retry: bool,
     *,
     run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
     **kwargs: Any,
@@ -269,7 +287,7 @@ async def acompletion_with_retry(
 
     retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
 
-    @retry_decorator
+    @conditional_decorator(use_retry, retry_decorator)
     async def _completion_with_retry(**kwargs: Any) -> Any:
         return await fireworks.client.ChatCompletion.acreate(
             **kwargs,
@@ -280,6 +298,7 @@ async def acompletion_with_retry(
 
 async def acompletion_with_retry_streaming(
     llm: ChatFireworks,
+    use_retry: bool,
     *,
     run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
     **kwargs: Any,
@@ -289,7 +308,7 @@ async def acompletion_with_retry_streaming(
 
     retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
 
-    @retry_decorator
+    @conditional_decorator(use_retry, retry_decorator)
     async def _completion_with_retry(**kwargs: Any) -> Any:
         return fireworks.client.ChatCompletion.acreate(
             **kwargs,
@@ -309,6 +328,8 @@ def _create_retry_decorator(
 
     errors = [
         fireworks.client.error.RateLimitError,
+        fireworks.client.error.InternalServerError,
+        fireworks.client.error.BadGatewayError,
         fireworks.client.error.ServiceUnavailableError,
     ]
     return create_base_retry_decorator(
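The retry opt-out above rests on the new conditional_decorator helper: when use_retry is False, the inner _completion_with_retry function is returned undecorated, so the tenacity wrapper built by _create_retry_decorator is simply skipped. A standalone sketch of the pattern for reference, with a made-up shout decorator standing in for the retry decorator:

from typing import Any, Callable

def conditional_decorator(
    condition: bool, decorator: Callable[[Any], Any]
) -> Callable[[Any], Any]:
    """Apply `decorator` only when `condition` is true; otherwise pass through."""
    def actual_decorator(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
        if condition:
            return decorator(func)
        return func
    return actual_decorator

def shout(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
    # Made-up stand-in for the tenacity retry decorator.
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        return str(func(*args, **kwargs)).upper()
    return wrapper

@conditional_decorator(True, shout)
def greet(name: str) -> str:
    return f"hello {name}"

@conditional_decorator(False, shout)
def greet_plain(name: str) -> str:
    return f"hello {name}"

print(greet("fireworks"))        # HELLO FIREWORKS  (decorator applied)
print(greet_plain("fireworks"))  # hello fireworks  (decorator skipped)

The same helper and retry changes are mirrored in the Fireworks completion LLM below, which additionally gains the batching code.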
@@ -1,12 +1,14 @@
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
 from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, Union
 
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain.llms.base import LLM, create_base_retry_decorator
+from langchain.llms.base import BaseLLM, create_base_retry_decorator
 from langchain.pydantic_v1 import Field, root_validator
-from langchain.schema.output import GenerationChunk
+from langchain.schema.output import Generation, GenerationChunk, LLMResult
 from langchain.utils.env import get_from_dict_or_env
 
 
@@ -23,7 +25,7 @@ def _stream_response_to_generation_chunk(
     )
 
 
-class Fireworks(LLM):
+class Fireworks(BaseLLM):
     """Fireworks models."""
 
     model: str = "accounts/fireworks/models/llama-v2-7b-chat"
@@ -36,6 +38,8 @@ class Fireworks(LLM):
     )
     fireworks_api_key: Optional[str] = None
     max_retries: int = 20
+    batch_size: int = 20
+    use_retry: bool = True
 
     @property
     def lc_secrets(self) -> Dict[str, str]:
@@ -66,43 +70,92 @@ class Fireworks(LLM):
         """Return type of llm."""
         return "fireworks"
 
-    def _call(
+    def _generate(
         self,
-        prompt: str,
+        prompts: List[str],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> str:
-        """Run the LLM on the given prompt and input."""
-        params: dict = {
+    ) -> LLMResult:
+        """Call out to Fireworks endpoint with k unique prompts.
+        Args:
+            prompts: The prompts to pass into the model.
+            stop: Optional list of stop words to use when generating.
+        Returns:
+            The full LLM output.
+        """
+        params = {
             "model": self.model,
-            "prompt": prompt,
             **self.model_kwargs,
         }
-        response = completion_with_retry(
-            self, run_manager=run_manager, stop=stop, **params
-        )
+        sub_prompts = self.get_batch_prompts(prompts)
+        choices = []
+        for _prompts in sub_prompts:
+            response = completion_with_retry_batching(
+                self,
+                self.use_retry,
+                prompt=_prompts,
+                run_manager=run_manager,
+                stop=stop,
+                **params,
+            )
+            choices.extend(response)
 
-        return response.choices[0].text
+        return self.create_llm_result(choices, prompts)
 
-    async def _acall(
+    async def _agenerate(
         self,
-        prompt: str,
+        prompts: List[str],
        stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> str:
-        """Run the LLM on the given prompt and input."""
+    ) -> LLMResult:
+        """Call out to Fireworks endpoint async with k unique prompts."""
         params = {
             "model": self.model,
-            "prompt": prompt,
             **self.model_kwargs,
         }
-        response = await acompletion_with_retry(
-            self, run_manager=run_manager, stop=stop, **params
-        )
+        sub_prompts = self.get_batch_prompts(prompts)
+        choices = []
+        for _prompts in sub_prompts:
+            response = await acompletion_with_retry_batching(
+                self,
+                self.use_retry,
+                prompt=_prompts,
+                run_manager=run_manager,
+                stop=stop,
+                **params,
+            )
+            choices.extend(response)
 
-        return response.choices[0].text
+        return self.create_llm_result(choices, prompts)
+
+    def get_batch_prompts(
+        self,
+        prompts: List[str],
+    ) -> List[List[str]]:
+        """Get the sub prompts for llm call."""
+        sub_prompts = [
+            prompts[i : i + self.batch_size]
+            for i in range(0, len(prompts), self.batch_size)
+        ]
+        return sub_prompts
+
+    def create_llm_result(self, choices: Any, prompts: List[str]) -> LLMResult:
+        """Create the LLMResult from the choices and prompts."""
+        generations = []
+        for i, _ in enumerate(prompts):
+            sub_choices = choices[i : (i + 1)]
+            generations.append(
+                [
+                    Generation(
+                        text=choice.__dict__["choices"][0].text,
+                    )
+                    for choice in sub_choices
+                ]
+            )
+        llm_output = {"model": self.model}
+        return LLMResult(generations=generations, llm_output=llm_output)
 
     def _stream(
         self,
@@ -118,7 +171,7 @@ class Fireworks(LLM):
             **self.model_kwargs,
         }
         for stream_resp in completion_with_retry(
-            self, run_manager=run_manager, stop=stop, **params
+            self, self.use_retry, run_manager=run_manager, stop=stop, **params
         ):
             chunk = _stream_response_to_generation_chunk(stream_resp)
             yield chunk
@@ -139,7 +192,7 @@ class Fireworks(LLM):
             **self.model_kwargs,
         }
         async for stream_resp in await acompletion_with_retry_streaming(
-            self, run_manager=run_manager, stop=stop, **params
+            self, self.use_retry, run_manager=run_manager, stop=stop, **params
         ):
             chunk = _stream_response_to_generation_chunk(stream_resp)
             yield chunk
@@ -147,8 +200,20 @@ class Fireworks(LLM):
             await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
 
 
+def conditional_decorator(
+    condition: bool, decorator: Callable[[Any], Any]
+) -> Callable[[Any], Any]:
+    def actual_decorator(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
+        if condition:
+            return decorator(func)
+        return func
+
+    return actual_decorator
+
+
 def completion_with_retry(
     llm: Fireworks,
+    use_retry: bool,
     *,
     run_manager: Optional[CallbackManagerForLLMRun] = None,
     **kwargs: Any,
@@ -158,7 +223,7 @@ def completion_with_retry(
 
     retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
 
-    @retry_decorator
+    @conditional_decorator(use_retry, retry_decorator)
     def _completion_with_retry(**kwargs: Any) -> Any:
         return fireworks.client.Completion.create(
             **kwargs,
@@ -169,6 +234,7 @@ def completion_with_retry(
 
 async def acompletion_with_retry(
     llm: Fireworks,
+    use_retry: bool,
     *,
     run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
     **kwargs: Any,
@@ -178,7 +244,7 @@ async def acompletion_with_retry(
 
     retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
 
-    @retry_decorator
+    @conditional_decorator(use_retry, retry_decorator)
     async def _completion_with_retry(**kwargs: Any) -> Any:
         return await fireworks.client.Completion.acreate(
             **kwargs,
@@ -187,8 +253,79 @@ async def acompletion_with_retry(
     return await _completion_with_retry(**kwargs)
 
 
+def completion_with_retry_batching(
+    llm: Fireworks,
+    use_retry: bool,
+    *,
+    run_manager: Optional[CallbackManagerForLLMRun] = None,
+    **kwargs: Any,
+) -> Any:
+    """Use tenacity to retry the completion call."""
+    import fireworks.client
+
+    prompt = kwargs["prompt"]
+    del kwargs["prompt"]
+
+    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+
+    @conditional_decorator(use_retry, retry_decorator)
+    def _completion_with_retry(prompt: str) -> Any:
+        return fireworks.client.Completion.create(**kwargs, prompt=prompt)
+
+    def batch_sync_run() -> List:
+        with ThreadPoolExecutor() as executor:
+            results = list(executor.map(_completion_with_retry, prompt))
+        return results
+
+    return batch_sync_run()
+
+
+async def acompletion_with_retry_batching(
+    llm: Fireworks,
+    use_retry: bool,
+    *,
+    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+    **kwargs: Any,
+) -> Any:
+    """Use tenacity to retry the completion call."""
+    import fireworks.client
+
+    prompt = kwargs["prompt"]
+    del kwargs["prompt"]
+
+    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+
+    @conditional_decorator(use_retry, retry_decorator)
+    async def _completion_with_retry(prompt: str) -> Any:
+        return await fireworks.client.Completion.acreate(**kwargs, prompt=prompt)
+
+    def run_coroutine_in_new_loop(
+        coroutine_func: Any, *args: Dict, **kwargs: Dict
+    ) -> Any:
+        new_loop = asyncio.new_event_loop()
+        try:
+            asyncio.set_event_loop(new_loop)
+            return new_loop.run_until_complete(coroutine_func(*args, **kwargs))
+        finally:
+            new_loop.close()
+
+    async def batch_sync_run() -> List:
+        with ThreadPoolExecutor() as executor:
+            results = list(
+                executor.map(
+                    run_coroutine_in_new_loop,
+                    [_completion_with_retry] * len(prompt),
+                    prompt,
+                )
+            )
        return results
+
+    return await batch_sync_run()
+
+
 async def acompletion_with_retry_streaming(
     llm: Fireworks,
+    use_retry: bool,
     *,
     run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
     **kwargs: Any,
@@ -198,7 +335,7 @@ async def acompletion_with_retry_streaming(
 
     retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
 
-    @retry_decorator
+    @conditional_decorator(use_retry, retry_decorator)
     async def _completion_with_retry(**kwargs: Any) -> Any:
         return fireworks.client.Completion.acreate(
             **kwargs,
@@ -219,6 +356,8 @@ def _create_retry_decorator(
 
     errors = [
         fireworks.client.error.RateLimitError,
+        fireworks.client.error.InternalServerError,
+        fireworks.client.error.BadGatewayError,
         fireworks.client.error.ServiceUnavailableError,
     ]
     return create_base_retry_decorator(
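The two batching helpers above take a list of prompts (passed through the prompt kwarg), issue one Fireworks completion per prompt across a thread pool, and, in the async case, run each coroutine in a fresh event loop inside its worker thread. A self-contained sketch of that fan-out pattern, using assumed stand-ins: fake_complete and fake_acomplete are placeholders for fireworks.client.Completion.create and acreate.

import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List

def fake_complete(prompt: str) -> Dict[str, Any]:
    # Placeholder for fireworks.client.Completion.create(**kwargs, prompt=prompt).
    return {"prompt": prompt, "text": f"echo: {prompt}"}

async def fake_acomplete(prompt: str) -> Dict[str, Any]:
    # Placeholder for fireworks.client.Completion.acreate(**kwargs, prompt=prompt).
    return {"prompt": prompt, "text": f"echo: {prompt}"}

def get_batch_prompts(prompts: List[str], batch_size: int) -> List[List[str]]:
    # Same chunking as Fireworks.get_batch_prompts.
    return [prompts[i : i + batch_size] for i in range(0, len(prompts), batch_size)]

def batch_sync_run(prompts: List[str]) -> List[Dict[str, Any]]:
    # Mirrors completion_with_retry_batching: one worker per prompt in the chunk.
    with ThreadPoolExecutor() as executor:
        return list(executor.map(fake_complete, prompts))

def run_coroutine_in_new_loop(coroutine_func: Any, prompt: str) -> Any:
    # Mirrors acompletion_with_retry_batching: each worker thread drives the
    # coroutine on its own private event loop.
    new_loop = asyncio.new_event_loop()
    try:
        asyncio.set_event_loop(new_loop)
        return new_loop.run_until_complete(coroutine_func(prompt))
    finally:
        new_loop.close()

def batch_async_run(prompts: List[str]) -> List[Dict[str, Any]]:
    with ThreadPoolExecutor() as executor:
        return list(
            executor.map(
                run_coroutine_in_new_loop,
                [fake_acomplete] * len(prompts),
                prompts,
            )
        )

if __name__ == "__main__":
    all_prompts = [f"prompt {i}" for i in range(5)]
    for chunk in get_batch_prompts(all_prompts, batch_size=2):
        print([r["text"] for r in batch_sync_run(chunk)])
        print([r["text"] for r in batch_async_run(chunk)])

The integration tests below exercise invoke/ainvoke, batch/abatch, and the new stop-word behavior for both the chat model and the completion LLM.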
@@ -3,11 +3,7 @@
 import pytest
 
 from langchain.chat_models.fireworks import ChatFireworks
-from langchain.schema import (
-    ChatGeneration,
-    ChatResult,
-    LLMResult,
-)
+from langchain.schema import ChatGeneration, ChatResult, LLMResult
 from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage
 
 
@@ -72,6 +68,64 @@ def test_chat_fireworks_llm_output_contains_model_id() -> None:
     assert llm_result.llm_output["model"] == chat.model
 
 
+def test_fireworks_invoke() -> None:
+    """Tests chat completion with invoke"""
+    chat = ChatFireworks()
+    result = chat.invoke("How is the weather in New York today?", stop=[","])
+    assert isinstance(result.content, str)
+    assert result.content[-1] == ","
+
+
+@pytest.mark.asyncio
+async def test_fireworks_ainvoke() -> None:
+    """Tests chat completion with invoke"""
+    chat = ChatFireworks()
+    result = await chat.ainvoke("How is the weather in New York today?", stop=[","])
+    assert isinstance(result.content, str)
+    assert result.content[-1] == ","
+
+
+def test_fireworks_batch() -> None:
+    """Test batch tokens from ChatFireworks."""
+    chat = ChatFireworks()
+    result = chat.batch(
+        [
+            "What is the weather in Redwood City, CA today",
+            "What is the weather in Redwood City, CA today",
+            "What is the weather in Redwood City, CA today",
+            "What is the weather in Redwood City, CA today",
+            "What is the weather in Redwood City, CA today",
+            "What is the weather in Redwood City, CA today",
+        ],
+        config={"max_concurrency": 5},
+        stop=[","],
+    )
+    for token in result:
+        assert isinstance(token.content, str)
+        assert token.content[-1] == ","
+
+
+@pytest.mark.asyncio
+async def test_fireworks_abatch() -> None:
+    """Test batch tokens from ChatFireworks."""
+    chat = ChatFireworks()
+    result = await chat.abatch(
+        [
+            "What is the weather in Redwood City, CA today",
+            "What is the weather in Redwood City, CA today",
+            "What is the weather in Redwood City, CA today",
+            "What is the weather in Redwood City, CA today",
+            "What is the weather in Redwood City, CA today",
+            "What is the weather in Redwood City, CA today",
+        ],
+        config={"max_concurrency": 5},
+        stop=[","],
+    )
+    for token in result:
+        assert isinstance(token.content, str)
+        assert token.content[-1] == ","
+
+
 def test_fireworks_streaming() -> None:
     """Test streaming tokens from Fireworks."""
     llm = ChatFireworks()
@@ -80,6 +134,17 @@ def test_fireworks_streaming() -> None:
         assert isinstance(token.content, str)
 
 
+def test_fireworks_streaming_stop_words() -> None:
+    """Test streaming tokens with stop words."""
+    llm = ChatFireworks()
+
+    last_token = ""
+    for token in llm.stream("I'm Pickle Rick", stop=[","]):
+        last_token = token.content
+        assert isinstance(token.content, str)
+    assert last_token[-1] == ","
+
+
 @pytest.mark.asyncio
 async def test_chat_fireworks_agenerate() -> None:
     """Test ChatFireworks wrapper with generate."""
@@ -101,5 +166,10 @@ async def test_fireworks_astream() -> None:
     """Test streaming tokens from Fireworks."""
     llm = ChatFireworks()
 
-    async for token in llm.astream("Who's the best quarterback in the NFL?"):
+    last_token = ""
+    async for token in llm.astream(
+        "Who's the best quarterback in the NFL?", stop=[","]
+    ):
+        last_token = token.content
         assert isinstance(token.content, str)
+    assert last_token[-1] == ","
@@ -16,7 +16,7 @@ from langchain.schema import LLMResult
 def test_fireworks_call() -> None:
     """Test valid call to fireworks."""
     llm = Fireworks()
-    output = llm("Who's the best quarterback in the NFL?")
+    output = llm("How is the weather in New York today?")
     assert isinstance(output, str)
 
 
@@ -41,6 +41,60 @@ def test_fireworks_model_param() -> None:
     assert llm.model == "foo"
 
 
+def test_fireworks_invoke() -> None:
+    """Tests completion with invoke"""
+    llm = Fireworks()
+    output = llm.invoke("How is the weather in New York today?", stop=[","])
+    assert isinstance(output, str)
+    assert output[-1] == ","
+
+
+@pytest.mark.asyncio
+async def test_fireworks_ainvoke() -> None:
+    """Tests completion with invoke"""
+    llm = Fireworks()
+    output = await llm.ainvoke("How is the weather in New York today?", stop=[","])
+    assert isinstance(output, str)
+    assert output[-1] == ","
+
+
+def test_fireworks_batch() -> None:
+    """Tests completion with invoke"""
+    llm = Fireworks()
+    output = llm.batch(
+        [
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+        ],
+        stop=[","],
+    )
+    for token in output:
+        assert isinstance(token, str)
+        assert token[-1] == ","
+
+
+@pytest.mark.asyncio
+async def test_fireworks_abatch() -> None:
+    """Tests completion with invoke"""
+    llm = Fireworks()
+    output = await llm.abatch(
+        [
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+        ],
+        stop=[","],
+    )
+    for token in output:
+        assert isinstance(token, str)
+        assert token[-1] == ","
+
+
 def test_fireworks_multiple_prompts() -> None:
     """Test completion with multiple prompts."""
     llm = Fireworks()
@@ -60,13 +114,31 @@ def test_fireworks_streaming() -> None:
         assert isinstance(token, str)
 
 
+def test_fireworks_streaming_stop_words() -> None:
+    """Test stream completion with stop words."""
+    llm = Fireworks()
+    generator = llm.stream("Who's the best quarterback in the NFL?", stop=[","])
+    assert isinstance(generator, Generator)
+
+    last_token = ""
+    for token in generator:
+        last_token = token
+        assert isinstance(token, str)
+    assert last_token[-1] == ","
+
+
 @pytest.mark.asyncio
 async def test_fireworks_streaming_async() -> None:
     """Test stream completion."""
     llm = Fireworks()
 
-    async for token in llm.astream("Who's the best quarterback in the NFL?"):
+    last_token = ""
+    async for token in llm.astream(
+        "Who's the best quarterback in the NFL?", stop=[","]
+    ):
+        last_token = token
         assert isinstance(token, str)
+    assert last_token[-1] == ","
 
 
 @pytest.mark.asyncio