From 52d0055a915e9d87f0175c0052a7754f1b73ccf8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kacper=20=C5=81ukawski?= Date: Mon, 6 Nov 2023 21:06:58 +0100 Subject: [PATCH] Add support of Cohere Embed v3 (#12940) Cohere released a new embedding API (Embed v3: https://txt.cohere.com/introducing-embed-v3/) that treats document and query embeddings differently. This PR updates `CohereEmbeddings` to pass the matching `input_type` to the API: `search_document` when embedding documents and `search_query` when embedding queries. It also works with the older models. --- libs/langchain/langchain/embeddings/cohere.py | 27 +++++++++++++++---- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/libs/langchain/langchain/embeddings/cohere.py b/libs/langchain/langchain/embeddings/cohere.py index e64019dd20f..0531390ce2e 100644 --- a/libs/langchain/langchain/embeddings/cohere.py +++ b/libs/langchain/langchain/embeddings/cohere.py @@ -87,7 +87,10 @@ class CohereEmbeddings(BaseModel, Embeddings): List of embeddings, one for each text. """ embeddings = self.client.embed( - model=self.model, texts=texts, truncate=self.truncate + model=self.model, + texts=texts, + input_type="search_document", + truncate=self.truncate, ).embeddings return [list(map(float, e)) for e in embeddings] @@ -101,7 +104,10 @@ class CohereEmbeddings(BaseModel, Embeddings): List of embeddings, one for each text. """ embeddings = await self.async_client.embed( - model=self.model, texts=texts, truncate=self.truncate + model=self.model, + texts=texts, + input_type="search_document", + truncate=self.truncate, ) return [list(map(float, e)) for e in embeddings.embeddings] @@ -114,7 +120,13 @@ class CohereEmbeddings(BaseModel, Embeddings): Returns: Embeddings for the text. """ - return self.embed_documents([text])[0] + embeddings = self.client.embed( + model=self.model, + texts=[text], + input_type="search_query", + truncate=self.truncate, + ).embeddings + return [list(map(float, e)) for e in embeddings][0] async def aembed_query(self, text: str) -> List[float]: """Async call out to Cohere's embedding endpoint. 
@@ -125,5 +137,10 @@ class CohereEmbeddings(BaseModel, Embeddings): Returns: Embeddings for the text. """ - embeddings = await self.aembed_documents([text]) - return embeddings[0] + embeddings = await self.async_client.embed( + model=self.model, + texts=[text], + input_type="search_query", + truncate=self.truncate, + ) + return [list(map(float, e)) for e in embeddings.embeddings][0]