Add input_type override (#14068)

Add an option to override `input_type` for Cohere's v3 embedding models.

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
billytrend-cohere 2023-12-04 20:10:24 +00:00 committed by GitHub
parent aaabc1574f
commit 0f02081392
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 60 additions and 44 deletions

File diff suppressed because one or more lines are too long

View File

@ -1,8 +1,7 @@
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from langchain_core.embeddings import Embeddings from langchain.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator from langchain.schema.embeddings import Embeddings
from langchain.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env
@ -18,7 +17,8 @@ class CohereEmbeddings(BaseModel, Embeddings):
from langchain.embeddings import CohereEmbeddings from langchain.embeddings import CohereEmbeddings
cohere = CohereEmbeddings( cohere = CohereEmbeddings(
model="embed-english-light-v3.0", cohere_api_key="my-api-key" model="embed-english-light-v3.0",
cohere_api_key="my-api-key"
) )
""" """
@ -78,8 +78,30 @@ class CohereEmbeddings(BaseModel, Embeddings):
) )
return values return values
def embed(
    self, texts: List[str], *, input_type: Optional[str] = None
) -> List[List[float]]:
    """Embed a batch of texts through the synchronous client.

    Args:
        texts: The list of texts to embed.
        input_type: Forwarded unchanged as the ``input_type`` argument of
            the client's ``embed`` call (``None`` by default).

    Returns:
        One list of floats per entry in the response's ``embeddings``.
    """
    response = self.client.embed(
        model=self.model,
        texts=texts,
        input_type=input_type,
        truncate=self.truncate,
    )
    vectors = []
    for raw_vector in response.embeddings:
        # Coerce every component to a plain Python float.
        vectors.append([float(component) for component in raw_vector])
    return vectors
async def aembed(
    self, texts: List[str], *, input_type: Optional[str] = None
) -> List[List[float]]:
    """Embed a batch of texts through the asynchronous client.

    Args:
        texts: The list of texts to embed.
        input_type: Forwarded unchanged as the ``input_type`` argument of
            the async client's ``embed`` call (``None`` by default).

    Returns:
        One list of floats per entry in the response's ``embeddings``.
    """
    # Await the call first and only then read `.embeddings`. Writing
    # `await client.embed(...).embeddings` binds the attribute access to the
    # un-awaited coroutine object (attribute access has higher precedence
    # than `await`), which raises AttributeError and never runs the request.
    response = await self.async_client.embed(
        model=self.model,
        texts=texts,
        input_type=input_type,
        truncate=self.truncate,
    )
    return [list(map(float, e)) for e in response.embeddings]
def embed_documents(self, texts: List[str]) -> List[List[float]]: def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Call out to Cohere's embedding endpoint. """Embed a list of document texts.
Args: Args:
texts: The list of texts to embed. texts: The list of texts to embed.
@ -87,13 +109,7 @@ class CohereEmbeddings(BaseModel, Embeddings):
Returns: Returns:
List of embeddings, one for each text. List of embeddings, one for each text.
""" """
embeddings = self.client.embed( return self.embed(texts, input_type="search_document")
model=self.model,
texts=texts,
input_type="search_document",
truncate=self.truncate,
).embeddings
return [list(map(float, e)) for e in embeddings]
async def aembed_documents(self, texts: List[str]) -> List[List[float]]: async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
"""Async call out to Cohere's embedding endpoint. """Async call out to Cohere's embedding endpoint.
@ -104,13 +120,7 @@ class CohereEmbeddings(BaseModel, Embeddings):
Returns: Returns:
List of embeddings, one for each text. List of embeddings, one for each text.
""" """
embeddings = await self.async_client.embed( return await self.aembed(texts, input_type="search_document")
model=self.model,
texts=texts,
input_type="search_document",
truncate=self.truncate,
)
return [list(map(float, e)) for e in embeddings.embeddings]
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> List[float]:
"""Call out to Cohere's embedding endpoint. """Call out to Cohere's embedding endpoint.
@ -121,13 +131,7 @@ class CohereEmbeddings(BaseModel, Embeddings):
Returns: Returns:
Embeddings for the text. Embeddings for the text.
""" """
embeddings = self.client.embed( return self.embed([text], input_type="search_query")[0]
model=self.model,
texts=[text],
input_type="search_query",
truncate=self.truncate,
).embeddings
return [list(map(float, e)) for e in embeddings][0]
async def aembed_query(self, text: str) -> List[float]: async def aembed_query(self, text: str) -> List[float]:
"""Async call out to Cohere's embedding endpoint. """Async call out to Cohere's embedding endpoint.
@ -138,10 +142,4 @@ class CohereEmbeddings(BaseModel, Embeddings):
Returns: Returns:
Embeddings for the text. Embeddings for the text.
""" """
embeddings = await self.async_client.embed( return (await self.aembed([text], input_type="search_query"))[0]
model=self.model,
texts=[text],
input_type="search_query",
truncate=self.truncate,
)
return [list(map(float, e)) for e in embeddings.embeddings][0]