community: Add support for cohere SDK v5 (keeps v4 backwards compatibility) (#19084)

- **Description:** Add support for cohere SDK v5 (keeps v4 backwards
compatibility)

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
billytrend-cohere
2024-03-14 17:53:24 -05:00
committed by GitHub
parent 06165efb5b
commit 7253b816cc
6 changed files with 101 additions and 47 deletions

View File

@@ -4,6 +4,8 @@ from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_core.utils import get_from_dict_or_env
from langchain_community.llms.cohere import _create_retry_decorator
class CohereEmbeddings(BaseModel, Embeddings):
"""Cohere embedding models.
@@ -34,7 +36,7 @@ class CohereEmbeddings(BaseModel, Embeddings):
cohere_api_key: Optional[str] = None
max_retries: Optional[int] = 3
max_retries: int = 3
"""Maximum number of retries to make when generating."""
request_timeout: Optional[float] = None
"""Timeout in seconds for the Cohere API request."""
@@ -52,7 +54,6 @@ class CohereEmbeddings(BaseModel, Embeddings):
cohere_api_key = get_from_dict_or_env(
values, "cohere_api_key", "COHERE_API_KEY"
)
max_retries = values.get("max_retries")
request_timeout = values.get("request_timeout")
try:
@@ -61,13 +62,11 @@ class CohereEmbeddings(BaseModel, Embeddings):
client_name = values["user_agent"]
values["client"] = cohere.Client(
cohere_api_key,
max_retries=max_retries,
timeout=request_timeout,
client_name=client_name,
)
values["async_client"] = cohere.AsyncClient(
cohere_api_key,
max_retries=max_retries,
timeout=request_timeout,
client_name=client_name,
)
@@ -78,10 +77,30 @@ class CohereEmbeddings(BaseModel, Embeddings):
)
return values
def embed_with_retry(self, **kwargs: Any) -> Any:
"""Use tenacity to retry the embed call."""
retry_decorator = _create_retry_decorator(self.max_retries)
@retry_decorator
def _embed_with_retry(**kwargs: Any) -> Any:
return self.client.embed(**kwargs)
return _embed_with_retry(**kwargs)
def aembed_with_retry(self, **kwargs: Any) -> Any:
"""Use tenacity to retry the embed call."""
retry_decorator = _create_retry_decorator(self.max_retries)
@retry_decorator
async def _embed_with_retry(**kwargs: Any) -> Any:
return await self.async_client.embed(**kwargs)
return _embed_with_retry(**kwargs)
def embed(
self, texts: List[str], *, input_type: Optional[str] = None
) -> List[List[float]]:
embeddings = self.client.embed(
embeddings = self.embed_with_retry(
model=self.model,
texts=texts,
input_type=input_type,
@@ -93,7 +112,7 @@ class CohereEmbeddings(BaseModel, Embeddings):
self, texts: List[str], *, input_type: Optional[str] = None
) -> List[List[float]]:
embeddings = (
await self.async_client.embed(
await self.aembed_with_retry(
model=self.model,
texts=texts,
input_type=input_type,