mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-02 11:39:18 +00:00
community: Add support for cohere SDK v5 (keeps v4 backwards compatibility) (#19084)
- **Description:** Add support for cohere SDK v5 (keeps v4 backwards compatibility) --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
committed by
GitHub
parent
06165efb5b
commit
7253b816cc
@@ -4,6 +4,8 @@ from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
|
||||
from langchain_community.llms.cohere import _create_retry_decorator
|
||||
|
||||
|
||||
class CohereEmbeddings(BaseModel, Embeddings):
|
||||
"""Cohere embedding models.
|
||||
@@ -34,7 +36,7 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
cohere_api_key: Optional[str] = None
|
||||
|
||||
max_retries: Optional[int] = 3
|
||||
max_retries: int = 3
|
||||
"""Maximum number of retries to make when generating."""
|
||||
request_timeout: Optional[float] = None
|
||||
"""Timeout in seconds for the Cohere API request."""
|
||||
@@ -52,7 +54,6 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
||||
cohere_api_key = get_from_dict_or_env(
|
||||
values, "cohere_api_key", "COHERE_API_KEY"
|
||||
)
|
||||
max_retries = values.get("max_retries")
|
||||
request_timeout = values.get("request_timeout")
|
||||
|
||||
try:
|
||||
@@ -61,13 +62,11 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
||||
client_name = values["user_agent"]
|
||||
values["client"] = cohere.Client(
|
||||
cohere_api_key,
|
||||
max_retries=max_retries,
|
||||
timeout=request_timeout,
|
||||
client_name=client_name,
|
||||
)
|
||||
values["async_client"] = cohere.AsyncClient(
|
||||
cohere_api_key,
|
||||
max_retries=max_retries,
|
||||
timeout=request_timeout,
|
||||
client_name=client_name,
|
||||
)
|
||||
@@ -78,10 +77,30 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
||||
)
|
||||
return values
|
||||
|
||||
def embed_with_retry(self, **kwargs: Any) -> Any:
|
||||
"""Use tenacity to retry the embed call."""
|
||||
retry_decorator = _create_retry_decorator(self.max_retries)
|
||||
|
||||
@retry_decorator
|
||||
def _embed_with_retry(**kwargs: Any) -> Any:
|
||||
return self.client.embed(**kwargs)
|
||||
|
||||
return _embed_with_retry(**kwargs)
|
||||
|
||||
def aembed_with_retry(self, **kwargs: Any) -> Any:
|
||||
"""Use tenacity to retry the embed call."""
|
||||
retry_decorator = _create_retry_decorator(self.max_retries)
|
||||
|
||||
@retry_decorator
|
||||
async def _embed_with_retry(**kwargs: Any) -> Any:
|
||||
return await self.async_client.embed(**kwargs)
|
||||
|
||||
return _embed_with_retry(**kwargs)
|
||||
|
||||
def embed(
|
||||
self, texts: List[str], *, input_type: Optional[str] = None
|
||||
) -> List[List[float]]:
|
||||
embeddings = self.client.embed(
|
||||
embeddings = self.embed_with_retry(
|
||||
model=self.model,
|
||||
texts=texts,
|
||||
input_type=input_type,
|
||||
@@ -93,7 +112,7 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
||||
self, texts: List[str], *, input_type: Optional[str] = None
|
||||
) -> List[List[float]]:
|
||||
embeddings = (
|
||||
await self.async_client.embed(
|
||||
await self.aembed_with_retry(
|
||||
model=self.model,
|
||||
texts=texts,
|
||||
input_type=input_type,
|
||||
|
Reference in New Issue
Block a user