mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 15:43:54 +00:00
community[patch]: Fixed the 'aembed' method of 'CohereEmbeddings'. (#16497)
**Description:** - The existing code was trying to find a `.embeddings` property on the `Coroutine` returned by calling `cohere.async_client.embed`. - Instead, the `.embeddings` property is present on the value returned by the `Coroutine`. - Also, it seems that the original cohere client expects a value of `max_retries` to not be `None`. Hence, setting the default value of `max_retries` to `3`. --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
9f1cbbc6ed
commit
37e1275f9e
@ -34,7 +34,7 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
cohere_api_key: Optional[str] = None
|
||||
|
||||
max_retries: Optional[int] = None
|
||||
max_retries: Optional[int] = 3
|
||||
"""Maximum number of retries to make when generating."""
|
||||
request_timeout: Optional[float] = None
|
||||
"""Timeout in seconds for the Cohere API request."""
|
||||
@ -92,11 +92,13 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
||||
async def aembed(
|
||||
self, texts: List[str], *, input_type: Optional[str] = None
|
||||
) -> List[List[float]]:
|
||||
embeddings = await self.async_client.embed(
|
||||
model=self.model,
|
||||
texts=texts,
|
||||
input_type=input_type,
|
||||
truncate=self.truncate,
|
||||
embeddings = (
|
||||
await self.async_client.embed(
|
||||
model=self.model,
|
||||
texts=texts,
|
||||
input_type=input_type,
|
||||
truncate=self.truncate,
|
||||
)
|
||||
).embeddings
|
||||
return [list(map(float, e)) for e in embeddings]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user