mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-16 17:53:37 +00:00
cohere retries (#5757)
…719) A minor update to retry Cohore API call in case of errors using tenacity as it is done for OpenAI LLMs. #### Who can review? @hwchase17, @agola11 <!-- For a quicker response, figure out the right person to tag with @ @hwchase17 - project lead Tracing / Callbacks - @agola11 Async - @agola11 DataLoaders - @eyurtsev Models - @hwchase17 - @agola11 Agents / Tools / Toolkits - @vowelparrot VectorStores / Retrievers / Memory - @dev2049 --> <!-- Thank you for contributing to LangChain! Your PR will appear in our release under the title you set. Please make sure it highlights your valuable contribution. Replace this with a description of the change, the issue it fixes (if applicable), and relevant context. List any dependencies required for this change. After you're done, someone will review your PR. They may suggest improvements. If no one reviews your PR within a few days, feel free to @-mention the same people again, as notifications can get lost. Finally, we'd love to show appreciation for your contribution - if you'd like us to shout you out on Twitter, please also include your handle! --> <!-- Remove if not applicable --> Fixes # (issue) #### Before submitting <!-- If you're adding a new integration, please include: 1. a test for the integration - favor unit tests that does not rely on network access. 2. an example notebook showing its use See contribution guidelines for more information on how to write tests, lint etc: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md --> #### Who can review? Tag maintainers/contributors who might be interested: <!-- For a quicker response, figure out the right person to tag with @ @hwchase17 - project lead Tracing / Callbacks - @agola11 Async - @agola11 DataLoaders - @eyurtsev Models - @hwchase17 - @agola11 Agents / Tools / Toolkits - @vowelparrot VectorStores / Retrievers / Memory - @dev2049 --> --------- Co-authored-by: Sagar Sapkota <22609549+sagar-spkt@users.noreply.github.com>
This commit is contained in:
parent
5124c1e0d9
commit
98dd6d068a
@ -1,8 +1,17 @@
|
||||
"""Wrapper around Cohere APIs."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from pydantic import Extra, root_validator
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
@ -12,6 +21,33 @@ from langchain.utils import get_from_dict_or_env
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _create_retry_decorator(llm: Cohere) -> Callable[[Any], Any]:
|
||||
import cohere
|
||||
|
||||
min_seconds = 4
|
||||
max_seconds = 10
|
||||
# Wait 2^x * 1 second between each retry starting with
|
||||
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
|
||||
return retry(
|
||||
reraise=True,
|
||||
stop=stop_after_attempt(llm.max_retries),
|
||||
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
|
||||
retry=(retry_if_exception_type(cohere.error.CohereError)),
|
||||
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||
)
|
||||
|
||||
|
||||
def completion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
|
||||
"""Use tenacity to retry the completion call."""
|
||||
retry_decorator = _create_retry_decorator(llm)
|
||||
|
||||
@retry_decorator
|
||||
def _completion_with_retry(**kwargs: Any) -> Any:
|
||||
return llm.client.generate(**kwargs)
|
||||
|
||||
return _completion_with_retry(**kwargs)
|
||||
|
||||
|
||||
class Cohere(LLM):
|
||||
"""Wrapper around Cohere large language models.
|
||||
|
||||
@ -52,6 +88,9 @@ class Cohere(LLM):
|
||||
"""Specify how the client handles inputs longer than the maximum token
|
||||
length: Truncate from START, END or NONE"""
|
||||
|
||||
max_retries: int = 10
|
||||
"""Maximum number of retries to make when generating."""
|
||||
|
||||
cohere_api_key: Optional[str] = None
|
||||
|
||||
stop: Optional[List[str]] = None
|
||||
@ -129,7 +168,9 @@ class Cohere(LLM):
|
||||
else:
|
||||
params["stop_sequences"] = stop
|
||||
|
||||
response = self.client.generate(model=self.model, prompt=prompt, **params)
|
||||
response = completion_with_retry(
|
||||
self, model=self.model, prompt=prompt, **params
|
||||
)
|
||||
text = response.generations[0].text
|
||||
# If stop tokens are provided, Cohere's endpoint returns them.
|
||||
# In order to make this consistent with other endpoints, we strip them.
|
||||
|
Loading…
Reference in New Issue
Block a user