Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-29 01:48:57 +00:00)
Commit 6ce276e099 (parent: 3fbb2f3e52)

Description:
* Add _generate and _agenerate to support Fireworks batching.
* Add stop-word test cases.
* Allow opting out of the retry mechanism (new use_retry flag).

Issue: Not applicable
Dependencies: None
Tag maintainer: @baskaryan
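For orientation, here is a minimal usage sketch of what the change enables. It is not part of the commit: the langchain.llms.fireworks import path is assumed to mirror the chat-model path used in the tests below, and running it would need the fireworks client installed plus a FIREWORKS_API_KEY in the environment (assumed variable name).

# Hedged sketch of the behavior added in this commit, not code from the diff.
from langchain.chat_models.fireworks import ChatFireworks
from langchain.llms.fireworks import Fireworks  # assumed module path

# New use_retry flag: opt out of the tenacity retry wrapper entirely.
chat = ChatFireworks(use_retry=False)
message = chat.invoke("How is the weather in New York today?", stop=[","])
print(message.content)

# Fireworks now implements _generate/_agenerate, so a multi-prompt call is
# split into sub-batches of `batch_size` prompts per request.
llm = Fireworks(batch_size=2)
result = llm.generate(["Tell me a joke.", "Name a planet.", "Say hi."])
print(len(result.generations))  # one generation list per input prompt
print(result.llm_output)        # {"model": "accounts/fireworks/models/llama-v2-7b-chat"}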
Changes to the ChatFireworks chat model:

@@ -89,6 +89,7 @@ class ChatFireworks(BaseChatModel):
     )
     fireworks_api_key: Optional[str] = None
     max_retries: int = 20
+    use_retry: bool = True

     @property
     def lc_secrets(self) -> Dict[str, str]:
@@ -134,7 +135,11 @@ class ChatFireworks(BaseChatModel):
             **self.model_kwargs,
         }
         response = completion_with_retry(
-            self, run_manager=run_manager, stop=stop, **params
+            self,
+            self.use_retry,
+            run_manager=run_manager,
+            stop=stop,
+            **params,
         )
         return self._create_chat_result(response)

@@ -152,7 +157,7 @@ class ChatFireworks(BaseChatModel):
             **self.model_kwargs,
         }
         response = await acompletion_with_retry(
-            self, run_manager=run_manager, stop=stop, **params
+            self, self.use_retry, run_manager=run_manager, stop=stop, **params
         )
         return self._create_chat_result(response)

@@ -195,7 +200,7 @@ class ChatFireworks(BaseChatModel):
             **self.model_kwargs,
         }
         for chunk in completion_with_retry(
-            self, run_manager=run_manager, stop=stop, **params
+            self, self.use_retry, run_manager=run_manager, stop=stop, **params
         ):
             choice = chunk.choices[0]
             chunk = _convert_delta_to_message_chunk(choice.delta, default_chunk_class)
@@ -224,7 +229,7 @@ class ChatFireworks(BaseChatModel):
             **self.model_kwargs,
         }
         async for chunk in await acompletion_with_retry_streaming(
-            self, run_manager=run_manager, stop=stop, **params
+            self, self.use_retry, run_manager=run_manager, stop=stop, **params
         ):
             choice = chunk.choices[0]
             chunk = _convert_delta_to_message_chunk(choice.delta, default_chunk_class)
@@ -238,8 +243,20 @@ class ChatFireworks(BaseChatModel):
                 await run_manager.on_llm_new_token(token=chunk.content, chunk=chunk)


+def conditional_decorator(
+    condition: bool, decorator: Callable[[Any], Any]
+) -> Callable[[Any], Any]:
+    def actual_decorator(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
+        if condition:
+            return decorator(func)
+        return func
+
+    return actual_decorator
+
+
 def completion_with_retry(
     llm: ChatFireworks,
+    use_retry: bool,
     *,
     run_manager: Optional[CallbackManagerForLLMRun] = None,
     **kwargs: Any,
@@ -249,7 +266,7 @@ def completion_with_retry(

     retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

-    @retry_decorator
+    @conditional_decorator(use_retry, retry_decorator)
     def _completion_with_retry(**kwargs: Any) -> Any:
         return fireworks.client.ChatCompletion.create(
             **kwargs,
@@ -260,6 +277,7 @@ def completion_with_retry(

 async def acompletion_with_retry(
     llm: ChatFireworks,
+    use_retry: bool,
     *,
     run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
     **kwargs: Any,
@@ -269,7 +287,7 @@ async def acompletion_with_retry(

     retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

-    @retry_decorator
+    @conditional_decorator(use_retry, retry_decorator)
     async def _completion_with_retry(**kwargs: Any) -> Any:
         return await fireworks.client.ChatCompletion.acreate(
             **kwargs,
@@ -280,6 +298,7 @@ async def acompletion_with_retry(

 async def acompletion_with_retry_streaming(
     llm: ChatFireworks,
+    use_retry: bool,
     *,
     run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
     **kwargs: Any,
@@ -289,7 +308,7 @@ async def acompletion_with_retry_streaming(

     retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

-    @retry_decorator
+    @conditional_decorator(use_retry, retry_decorator)
     async def _completion_with_retry(**kwargs: Any) -> Any:
         return fireworks.client.ChatCompletion.acreate(
             **kwargs,
@@ -309,6 +328,8 @@ def _create_retry_decorator(

     errors = [
         fireworks.client.error.RateLimitError,
+        fireworks.client.error.InternalServerError,
+        fireworks.client.error.BadGatewayError,
         fireworks.client.error.ServiceUnavailableError,
     ]
     return create_base_retry_decorator(
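Both modules gain the same conditional_decorator helper, which is what makes the retry opt-out work: when use_retry is False, the tenacity decorator returned by _create_retry_decorator is simply never applied to the inner call. Below is a minimal, self-contained illustration of that pattern; the uppercase decorator is a stand-in for the real retry decorator and is not part of this commit.

from typing import Any, Callable


def conditional_decorator(
    condition: bool, decorator: Callable[[Any], Any]
) -> Callable[[Any], Any]:
    """Apply `decorator` only when `condition` is truthy; otherwise return the function unchanged."""

    def actual_decorator(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
        if condition:
            return decorator(func)
        return func

    return actual_decorator


def uppercase(func: Callable[..., str]) -> Callable[..., str]:
    # Stand-in for the tenacity retry decorator used in the real code.
    def wrapper(*args: Any, **kwargs: Any) -> str:
        return func(*args, **kwargs).upper()

    return wrapper


@conditional_decorator(True, uppercase)   # decorator applied
def wrapped(name: str) -> str:
    return f"hello {name}"


@conditional_decorator(False, uppercase)  # decorator skipped: the use_retry=False path
def unwrapped(name: str) -> str:
    return f"hello {name}"


print(wrapped("fireworks"))    # HELLO FIREWORKS
print(unwrapped("fireworks"))  # hello fireworks

Because the decoration happens inside each completion helper at call time, the retry choice is threaded through as a plain argument rather than being baked into the class.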
Changes to the Fireworks LLM:

@@ -1,12 +1,14 @@
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
 from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, Union

 from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain.llms.base import LLM, create_base_retry_decorator
+from langchain.llms.base import BaseLLM, create_base_retry_decorator
 from langchain.pydantic_v1 import Field, root_validator
-from langchain.schema.output import GenerationChunk
+from langchain.schema.output import Generation, GenerationChunk, LLMResult
 from langchain.utils.env import get_from_dict_or_env


@@ -23,7 +25,7 @@ def _stream_response_to_generation_chunk(
     )


-class Fireworks(LLM):
+class Fireworks(BaseLLM):
     """Fireworks models."""

     model: str = "accounts/fireworks/models/llama-v2-7b-chat"
@@ -36,6 +38,8 @@ class Fireworks(LLM):
     )
     fireworks_api_key: Optional[str] = None
     max_retries: int = 20
+    batch_size: int = 20
+    use_retry: bool = True

     @property
     def lc_secrets(self) -> Dict[str, str]:
@@ -66,43 +70,92 @@ class Fireworks(LLM):
         """Return type of llm."""
         return "fireworks"

-    def _call(
+    def _generate(
         self,
-        prompt: str,
+        prompts: List[str],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> str:
-        """Run the LLM on the given prompt and input."""
-        params: dict = {
+    ) -> LLMResult:
+        """Call out to Fireworks endpoint with k unique prompts.
+        Args:
+            prompts: The prompts to pass into the model.
+            stop: Optional list of stop words to use when generating.
+        Returns:
+            The full LLM output.
+        """
+        params = {
             "model": self.model,
-            "prompt": prompt,
             **self.model_kwargs,
         }
-        response = completion_with_retry(
-            self, run_manager=run_manager, stop=stop, **params
-        )
+        sub_prompts = self.get_batch_prompts(prompts)
+        choices = []
+        for _prompts in sub_prompts:
+            response = completion_with_retry_batching(
+                self,
+                self.use_retry,
+                prompt=_prompts,
+                run_manager=run_manager,
+                stop=stop,
+                **params,
+            )
+            choices.extend(response)

-        return response.choices[0].text
+        return self.create_llm_result(choices, prompts)

-    async def _acall(
+    async def _agenerate(
         self,
-        prompt: str,
+        prompts: List[str],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> str:
-        """Run the LLM on the given prompt and input."""
+    ) -> LLMResult:
+        """Call out to Fireworks endpoint async with k unique prompts."""
         params = {
             "model": self.model,
-            "prompt": prompt,
             **self.model_kwargs,
         }
-        response = await acompletion_with_retry(
-            self, run_manager=run_manager, stop=stop, **params
-        )
+        sub_prompts = self.get_batch_prompts(prompts)
+        choices = []
+        for _prompts in sub_prompts:
+            response = await acompletion_with_retry_batching(
+                self,
+                self.use_retry,
+                prompt=_prompts,
+                run_manager=run_manager,
+                stop=stop,
+                **params,
+            )
+            choices.extend(response)

-        return response.choices[0].text
+        return self.create_llm_result(choices, prompts)

+    def get_batch_prompts(
+        self,
+        prompts: List[str],
+    ) -> List[List[str]]:
+        """Get the sub prompts for llm call."""
+        sub_prompts = [
+            prompts[i : i + self.batch_size]
+            for i in range(0, len(prompts), self.batch_size)
+        ]
+        return sub_prompts
+
+    def create_llm_result(self, choices: Any, prompts: List[str]) -> LLMResult:
+        """Create the LLMResult from the choices and prompts."""
+        generations = []
+        for i, _ in enumerate(prompts):
+            sub_choices = choices[i : (i + 1)]
+            generations.append(
+                [
+                    Generation(
+                        text=choice.__dict__["choices"][0].text,
+                    )
+                    for choice in sub_choices
+                ]
+            )
+        llm_output = {"model": self.model}
+        return LLMResult(generations=generations, llm_output=llm_output)
+
     def _stream(
         self,
@@ -118,7 +171,7 @@ class Fireworks(LLM):
             **self.model_kwargs,
         }
         for stream_resp in completion_with_retry(
-            self, run_manager=run_manager, stop=stop, **params
+            self, self.use_retry, run_manager=run_manager, stop=stop, **params
         ):
             chunk = _stream_response_to_generation_chunk(stream_resp)
             yield chunk
@@ -139,7 +192,7 @@ class Fireworks(LLM):
             **self.model_kwargs,
         }
         async for stream_resp in await acompletion_with_retry_streaming(
-            self, run_manager=run_manager, stop=stop, **params
+            self, self.use_retry, run_manager=run_manager, stop=stop, **params
         ):
             chunk = _stream_response_to_generation_chunk(stream_resp)
             yield chunk
@@ -147,8 +200,20 @@ class Fireworks(LLM):
                 await run_manager.on_llm_new_token(chunk.text, chunk=chunk)


+def conditional_decorator(
+    condition: bool, decorator: Callable[[Any], Any]
+) -> Callable[[Any], Any]:
+    def actual_decorator(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
+        if condition:
+            return decorator(func)
+        return func
+
+    return actual_decorator
+
+
 def completion_with_retry(
     llm: Fireworks,
+    use_retry: bool,
     *,
     run_manager: Optional[CallbackManagerForLLMRun] = None,
     **kwargs: Any,
@@ -158,7 +223,7 @@ def completion_with_retry(

     retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

-    @retry_decorator
+    @conditional_decorator(use_retry, retry_decorator)
     def _completion_with_retry(**kwargs: Any) -> Any:
         return fireworks.client.Completion.create(
             **kwargs,
@@ -169,6 +234,7 @@ def completion_with_retry(

 async def acompletion_with_retry(
     llm: Fireworks,
+    use_retry: bool,
     *,
     run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
     **kwargs: Any,
@@ -178,7 +244,7 @@ async def acompletion_with_retry(

     retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

-    @retry_decorator
+    @conditional_decorator(use_retry, retry_decorator)
     async def _completion_with_retry(**kwargs: Any) -> Any:
         return await fireworks.client.Completion.acreate(
             **kwargs,
@@ -187,8 +253,79 @@ async def acompletion_with_retry(
     return await _completion_with_retry(**kwargs)


+def completion_with_retry_batching(
+    llm: Fireworks,
+    use_retry: bool,
+    *,
+    run_manager: Optional[CallbackManagerForLLMRun] = None,
+    **kwargs: Any,
+) -> Any:
+    """Use tenacity to retry the completion call."""
+    import fireworks.client
+
+    prompt = kwargs["prompt"]
+    del kwargs["prompt"]
+
+    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+
+    @conditional_decorator(use_retry, retry_decorator)
+    def _completion_with_retry(prompt: str) -> Any:
+        return fireworks.client.Completion.create(**kwargs, prompt=prompt)
+
+    def batch_sync_run() -> List:
+        with ThreadPoolExecutor() as executor:
+            results = list(executor.map(_completion_with_retry, prompt))
+        return results
+
+    return batch_sync_run()
+
+
+async def acompletion_with_retry_batching(
+    llm: Fireworks,
+    use_retry: bool,
+    *,
+    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+    **kwargs: Any,
+) -> Any:
+    """Use tenacity to retry the completion call."""
+    import fireworks.client
+
+    prompt = kwargs["prompt"]
+    del kwargs["prompt"]
+
+    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+
+    @conditional_decorator(use_retry, retry_decorator)
+    async def _completion_with_retry(prompt: str) -> Any:
+        return await fireworks.client.Completion.acreate(**kwargs, prompt=prompt)
+
+    def run_coroutine_in_new_loop(
+        coroutine_func: Any, *args: Dict, **kwargs: Dict
+    ) -> Any:
+        new_loop = asyncio.new_event_loop()
+        try:
+            asyncio.set_event_loop(new_loop)
+            return new_loop.run_until_complete(coroutine_func(*args, **kwargs))
+        finally:
+            new_loop.close()
+
+    async def batch_sync_run() -> List:
+        with ThreadPoolExecutor() as executor:
+            results = list(
+                executor.map(
+                    run_coroutine_in_new_loop,
+                    [_completion_with_retry] * len(prompt),
+                    prompt,
+                )
+            )
+        return results
+
+    return await batch_sync_run()
+
+
 async def acompletion_with_retry_streaming(
     llm: Fireworks,
+    use_retry: bool,
     *,
     run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
     **kwargs: Any,
@@ -198,7 +335,7 @@ async def acompletion_with_retry_streaming(

     retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

-    @retry_decorator
+    @conditional_decorator(use_retry, retry_decorator)
     async def _completion_with_retry(**kwargs: Any) -> Any:
         return fireworks.client.Completion.acreate(
             **kwargs,
@@ -219,6 +356,8 @@ def _create_retry_decorator(

     errors = [
         fireworks.client.error.RateLimitError,
+        fireworks.client.error.InternalServerError,
+        fireworks.client.error.BadGatewayError,
         fireworks.client.error.ServiceUnavailableError,
     ]
     return create_base_retry_decorator(
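The batching path added above has two parts: get_batch_prompts slices the prompt list into chunks of batch_size, and completion_with_retry_batching fans each chunk out over a ThreadPoolExecutor (the async variant gives each worker thread its own event loop via run_coroutine_in_new_loop rather than awaiting the coroutines with asyncio.gather). create_llm_result then maps the flat response list back to one Generation per prompt. A minimal sketch of that shape follows; fake_completion is a placeholder for the fireworks.client.Completion.create call and is not part of the diff.

# Hedged sketch of the batching shape, with a placeholder in place of the API call.
from concurrent.futures import ThreadPoolExecutor
from typing import List

BATCH_SIZE = 2  # analogous to the new Fireworks.batch_size field


def get_batch_prompts(prompts: List[str]) -> List[List[str]]:
    # Same slicing as the new Fireworks.get_batch_prompts helper.
    return [prompts[i : i + BATCH_SIZE] for i in range(0, len(prompts), BATCH_SIZE)]


def fake_completion(prompt: str) -> str:
    # Placeholder for fireworks.client.Completion.create(...).
    return f"completion for: {prompt}"


def run_batch(prompts: List[str]) -> List[str]:
    results: List[str] = []
    for sub in get_batch_prompts(prompts):
        # Each sub-batch is fanned out across threads, mirroring batch_sync_run();
        # executor.map preserves the input order, so results line up with prompts.
        with ThreadPoolExecutor() as executor:
            results.extend(executor.map(fake_completion, sub))
    return results


print(run_batch(["a", "b", "c", "d", "e"]))  # five completions, in prompt order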
Changes to the ChatFireworks integration tests:

@@ -3,11 +3,7 @@
 import pytest

 from langchain.chat_models.fireworks import ChatFireworks
-from langchain.schema import (
-    ChatGeneration,
-    ChatResult,
-    LLMResult,
-)
+from langchain.schema import ChatGeneration, ChatResult, LLMResult
 from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage


@@ -72,6 +68,64 @@ def test_chat_fireworks_llm_output_contains_model_id() -> None:
     assert llm_result.llm_output["model"] == chat.model


+def test_fireworks_invoke() -> None:
+    """Tests chat completion with invoke"""
+    chat = ChatFireworks()
+    result = chat.invoke("How is the weather in New York today?", stop=[","])
+    assert isinstance(result.content, str)
+    assert result.content[-1] == ","
+
+
+@pytest.mark.asyncio
+async def test_fireworks_ainvoke() -> None:
+    """Tests chat completion with invoke"""
+    chat = ChatFireworks()
+    result = await chat.ainvoke("How is the weather in New York today?", stop=[","])
+    assert isinstance(result.content, str)
+    assert result.content[-1] == ","
+
+
+def test_fireworks_batch() -> None:
+    """Test batch tokens from ChatFireworks."""
+    chat = ChatFireworks()
+    result = chat.batch(
+        [
+            "What is the weather in Redwood City, CA today",
+            "What is the weather in Redwood City, CA today",
+            "What is the weather in Redwood City, CA today",
+            "What is the weather in Redwood City, CA today",
+            "What is the weather in Redwood City, CA today",
+            "What is the weather in Redwood City, CA today",
+        ],
+        config={"max_concurrency": 5},
+        stop=[","],
+    )
+    for token in result:
+        assert isinstance(token.content, str)
+        assert token.content[-1] == ","
+
+
+@pytest.mark.asyncio
+async def test_fireworks_abatch() -> None:
+    """Test batch tokens from ChatFireworks."""
+    chat = ChatFireworks()
+    result = await chat.abatch(
+        [
+            "What is the weather in Redwood City, CA today",
+            "What is the weather in Redwood City, CA today",
+            "What is the weather in Redwood City, CA today",
+            "What is the weather in Redwood City, CA today",
+            "What is the weather in Redwood City, CA today",
+            "What is the weather in Redwood City, CA today",
+        ],
+        config={"max_concurrency": 5},
+        stop=[","],
+    )
+    for token in result:
+        assert isinstance(token.content, str)
+        assert token.content[-1] == ","
+
+
 def test_fireworks_streaming() -> None:
     """Test streaming tokens from Fireworks."""
     llm = ChatFireworks()
@@ -80,6 +134,17 @@ def test_fireworks_streaming() -> None:
         assert isinstance(token.content, str)


+def test_fireworks_streaming_stop_words() -> None:
+    """Test streaming tokens with stop words."""
+    llm = ChatFireworks()
+
+    last_token = ""
+    for token in llm.stream("I'm Pickle Rick", stop=[","]):
+        last_token = token.content
+        assert isinstance(token.content, str)
+    assert last_token[-1] == ","
+
+
 @pytest.mark.asyncio
 async def test_chat_fireworks_agenerate() -> None:
     """Test ChatFireworks wrapper with generate."""
@@ -101,5 +166,10 @@ async def test_fireworks_astream() -> None:
     """Test streaming tokens from Fireworks."""
     llm = ChatFireworks()

-    async for token in llm.astream("Who's the best quarterback in the NFL?"):
+    last_token = ""
+    async for token in llm.astream(
+        "Who's the best quarterback in the NFL?", stop=[","]
+    ):
+        last_token = token.content
         assert isinstance(token.content, str)
+    assert last_token[-1] == ","
Changes to the Fireworks LLM integration tests:

@@ -16,7 +16,7 @@ from langchain.schema import LLMResult
 def test_fireworks_call() -> None:
     """Test valid call to fireworks."""
     llm = Fireworks()
-    output = llm("Who's the best quarterback in the NFL?")
+    output = llm("How is the weather in New York today?")
     assert isinstance(output, str)


@@ -41,6 +41,60 @@ def test_fireworks_model_param() -> None:
     assert llm.model == "foo"


+def test_fireworks_invoke() -> None:
+    """Tests completion with invoke"""
+    llm = Fireworks()
+    output = llm.invoke("How is the weather in New York today?", stop=[","])
+    assert isinstance(output, str)
+    assert output[-1] == ","
+
+
+@pytest.mark.asyncio
+async def test_fireworks_ainvoke() -> None:
+    """Tests completion with invoke"""
+    llm = Fireworks()
+    output = await llm.ainvoke("How is the weather in New York today?", stop=[","])
+    assert isinstance(output, str)
+    assert output[-1] == ","
+
+
+def test_fireworks_batch() -> None:
+    """Tests completion with invoke"""
+    llm = Fireworks()
+    output = llm.batch(
+        [
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+        ],
+        stop=[","],
+    )
+    for token in output:
+        assert isinstance(token, str)
+        assert token[-1] == ","
+
+
+@pytest.mark.asyncio
+async def test_fireworks_abatch() -> None:
+    """Tests completion with invoke"""
+    llm = Fireworks()
+    output = await llm.abatch(
+        [
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+        ],
+        stop=[","],
+    )
+    for token in output:
+        assert isinstance(token, str)
+        assert token[-1] == ","
+
+
 def test_fireworks_multiple_prompts() -> None:
     """Test completion with multiple prompts."""
     llm = Fireworks()
@@ -60,13 +114,31 @@ def test_fireworks_streaming() -> None:
         assert isinstance(token, str)


+def test_fireworks_streaming_stop_words() -> None:
+    """Test stream completion with stop words."""
+    llm = Fireworks()
+    generator = llm.stream("Who's the best quarterback in the NFL?", stop=[","])
+    assert isinstance(generator, Generator)
+
+    last_token = ""
+    for token in generator:
+        last_token = token
+        assert isinstance(token, str)
+    assert last_token[-1] == ","
+
+
 @pytest.mark.asyncio
 async def test_fireworks_streaming_async() -> None:
     """Test stream completion."""
     llm = Fireworks()

-    async for token in llm.astream("Who's the best quarterback in the NFL?"):
+    last_token = ""
+    async for token in llm.astream(
+        "Who's the best quarterback in the NFL?", stop=[","]
+    ):
+        last_token = token
         assert isinstance(token, str)
+    assert last_token[-1] == ","


 @pytest.mark.asyncio