diff --git a/libs/langchain/langchain/embeddings/cohere.py b/libs/langchain/langchain/embeddings/cohere.py index e64019dd20f..0531390ce2e 100644 --- a/libs/langchain/langchain/embeddings/cohere.py +++ b/libs/langchain/langchain/embeddings/cohere.py @@ -87,7 +87,10 @@ class CohereEmbeddings(BaseModel, Embeddings): List of embeddings, one for each text. """ embeddings = self.client.embed( - model=self.model, texts=texts, truncate=self.truncate + model=self.model, + texts=texts, + input_type="search_document", + truncate=self.truncate, ).embeddings return [list(map(float, e)) for e in embeddings] @@ -101,7 +104,10 @@ class CohereEmbeddings(BaseModel, Embeddings): List of embeddings, one for each text. """ embeddings = await self.async_client.embed( - model=self.model, texts=texts, truncate=self.truncate + model=self.model, + texts=texts, + input_type="search_document", + truncate=self.truncate, ) return [list(map(float, e)) for e in embeddings.embeddings] @@ -114,7 +120,13 @@ class CohereEmbeddings(BaseModel, Embeddings): Returns: Embeddings for the text. """ - return self.embed_documents([text])[0] + embeddings = self.client.embed( + model=self.model, + texts=[text], + input_type="search_query", + truncate=self.truncate, + ).embeddings + return [list(map(float, e)) for e in embeddings][0] async def aembed_query(self, text: str) -> List[float]: """Async call out to Cohere's embedding endpoint. @@ -125,5 +137,10 @@ class CohereEmbeddings(BaseModel, Embeddings): Returns: Embeddings for the text. """ - embeddings = await self.aembed_documents([text]) - return embeddings[0] + embeddings = await self.async_client.embed( + model=self.model, + texts=[text], + input_type="search_query", + truncate=self.truncate, + ) + return [list(map(float, e)) for e in embeddings.embeddings][0]