mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-18 18:53:10 +00:00
CohereEmbeddings: Add max_retries and request_timeout (#12275)
Add max_retries and request_timeout to CohereEmbeddings, akin to how it works in OpenAIEmbeddings. Since the Cohere client already implements these parameters, we can simply pass them down. Uses parameters from these two cohere client objects: https://github.com/cohere-ai/cohere-python/blob/main/cohere/client.py https://github.com/cohere-ai/cohere-python/blob/main/cohere/client_async.py
This commit is contained in:
parent
7108084947
commit
c26ec7789f
@ -33,6 +33,11 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
cohere_api_key: Optional[str] = None
|
||||
|
||||
max_retries: Optional[int] = None
|
||||
"""Maximum number of retries to make when generating."""
|
||||
request_timeout: Optional[float] = None
|
||||
"""Timeout in seconds for the Cohere API request."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
@ -44,11 +49,18 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
||||
cohere_api_key = get_from_dict_or_env(
|
||||
values, "cohere_api_key", "COHERE_API_KEY"
|
||||
)
|
||||
max_retries = values.get("max_retries")
|
||||
request_timeout = values.get("request_timeout")
|
||||
|
||||
try:
|
||||
import cohere
|
||||
|
||||
values["client"] = cohere.Client(cohere_api_key)
|
||||
values["async_client"] = cohere.AsyncClient(cohere_api_key)
|
||||
values["client"] = cohere.Client(
|
||||
cohere_api_key, max_retries=max_retries, timeout=request_timeout
|
||||
)
|
||||
values["async_client"] = cohere.AsyncClient(
|
||||
cohere_api_key, max_retries=max_retries, timeout=request_timeout
|
||||
)
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import cohere python package. "
|
||||
|
Loading…
Reference in New Issue
Block a user