From c26ec7789f528fb8200ef3bad499c5e1cc348b30 Mon Sep 17 00:00:00 2001 From: Johanna Appel Date: Wed, 25 Oct 2023 18:37:25 +0100 Subject: [PATCH] CohereEmbeddings: Add max_retries and request_timeout (#12275) Add max_retries and request_timeout to CohereEmbeddings, akin to how it works in OpenAIEmbeddings. Since the Cohere client already implements these parameters, we can simply pass them down. Uses parameters from these two cohere client objects: https://github.com/cohere-ai/cohere-python/blob/main/cohere/client.py https://github.com/cohere-ai/cohere-python/blob/main/cohere/client_async.py --- libs/langchain/langchain/embeddings/cohere.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/libs/langchain/langchain/embeddings/cohere.py b/libs/langchain/langchain/embeddings/cohere.py index 1c908876f9d..1a6b3cdca0d 100644 --- a/libs/langchain/langchain/embeddings/cohere.py +++ b/libs/langchain/langchain/embeddings/cohere.py @@ -33,6 +33,11 @@ class CohereEmbeddings(BaseModel, Embeddings): cohere_api_key: Optional[str] = None + max_retries: Optional[int] = None + """Maximum number of retries to make when generating.""" + request_timeout: Optional[float] = None + """Timeout in seconds for the Cohere API request.""" + class Config: """Configuration for this pydantic object.""" @@ -44,11 +49,18 @@ class CohereEmbeddings(BaseModel, Embeddings): cohere_api_key = get_from_dict_or_env( values, "cohere_api_key", "COHERE_API_KEY" ) + max_retries = values.get("max_retries") + request_timeout = values.get("request_timeout") + try: import cohere - values["client"] = cohere.Client(cohere_api_key) - values["async_client"] = cohere.AsyncClient(cohere_api_key) + values["client"] = cohere.Client( + cohere_api_key, max_retries=max_retries, timeout=request_timeout + ) + values["async_client"] = cohere.AsyncClient( + cohere_api_key, max_retries=max_retries, timeout=request_timeout + ) except ImportError: raise ValueError( "Could not import cohere python package. "