From ebea40ce862f60af1961e78e56f6fce6112e92b8 Mon Sep 17 00:00:00 2001 From: Johanna Appel Date: Wed, 1 Feb 2023 16:09:03 +0100 Subject: [PATCH] Add 'truncate' parameter for CohereEmbeddings (#798) Currently, the 'truncate' parameter of the cohere API is not supported. This means that by default, if trying to generate and embedding that is too big, the call will just fail with an error (which is frustrating if using this embedding source e.g. with GPT-Index, because it's hard to handle it properly when generating a lot of embeddings). With the parameter, one can decide to either truncate the START or END of the text to fit the max token length and still generate an embedding without throwing the error. In this PR, I added this parameter to the class. _Arguably, there should be a better way to handle this error, e.g. by optionally calling a function or so that gets triggered when the token limit is reached and can split the document or some such. Especially in the use case with GPT-Index, its often hard to estimate the token counts for each document and I'd rather sort out the troublemakers or simply split them than interrupting the whole execution. Thoughts?_ --------- Co-authored-by: Harrison Chase --- langchain/embeddings/cohere.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/langchain/embeddings/cohere.py b/langchain/embeddings/cohere.py index b55e4aaafae..de3648a906d 100644 --- a/langchain/embeddings/cohere.py +++ b/langchain/embeddings/cohere.py @@ -25,6 +25,9 @@ class CohereEmbeddings(BaseModel, Embeddings): model: str = "large" """Model name to use.""" + truncate: str = "NONE" + """Truncate embeddings that are too long from start or end ("NONE"|"START"|"END")""" + cohere_api_key: Optional[str] = None class Config: @@ -58,7 +61,9 @@ class CohereEmbeddings(BaseModel, Embeddings): Returns: List of embeddings, one for each text. """ - embeddings = self.client.embed(model=self.model, texts=texts).embeddings + embeddings = self.client.embed( + model=self.model, texts=texts, truncate=self.truncate + ).embeddings return embeddings def embed_query(self, text: str) -> List[float]: @@ -70,5 +75,7 @@ class CohereEmbeddings(BaseModel, Embeddings): Returns: Embeddings for the text. """ - embedding = self.client.embed(model=self.model, texts=[text]).embeddings[0] + embedding = self.client.embed( + model=self.model, texts=[text], truncate=self.truncate + ).embeddings[0] return embedding