mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-10 15:06:18 +00:00
Add input_type override (#14068)
Add an option to override input_type for Cohere's v3 embedding models.
---------
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
aaabc1574f
commit
0f02081392
File diff suppressed because one or more lines are too long
@ -1,8 +1,7 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
@ -18,7 +17,8 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
from langchain.embeddings import CohereEmbeddings
|
||||
cohere = CohereEmbeddings(
|
||||
model="embed-english-light-v3.0", cohere_api_key="my-api-key"
|
||||
model="embed-english-light-v3.0",
|
||||
cohere_api_key="my-api-key"
|
||||
)
|
||||
"""
|
||||
|
||||
@ -78,8 +78,30 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
||||
)
|
||||
return values
|
||||
|
||||
def embed(
|
||||
self, texts: List[str], *, input_type: Optional[str] = None
|
||||
) -> List[List[float]]:
|
||||
embeddings = self.client.embed(
|
||||
model=self.model,
|
||||
texts=texts,
|
||||
input_type=input_type,
|
||||
truncate=self.truncate,
|
||||
).embeddings
|
||||
return [list(map(float, e)) for e in embeddings]
|
||||
|
||||
async def aembed(
|
||||
self, texts: List[str], *, input_type: Optional[str] = None
|
||||
) -> List[List[float]]:
|
||||
embeddings = await self.async_client.embed(
|
||||
model=self.model,
|
||||
texts=texts,
|
||||
input_type=input_type,
|
||||
truncate=self.truncate,
|
||||
).embeddings
|
||||
return [list(map(float, e)) for e in embeddings]
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Call out to Cohere's embedding endpoint.
|
||||
"""Embed a list of document texts.
|
||||
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
@ -87,13 +109,7 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
embeddings = self.client.embed(
|
||||
model=self.model,
|
||||
texts=texts,
|
||||
input_type="search_document",
|
||||
truncate=self.truncate,
|
||||
).embeddings
|
||||
return [list(map(float, e)) for e in embeddings]
|
||||
return self.embed(texts, input_type="search_document")
|
||||
|
||||
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Async call out to Cohere's embedding endpoint.
|
||||
@ -104,13 +120,7 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
embeddings = await self.async_client.embed(
|
||||
model=self.model,
|
||||
texts=texts,
|
||||
input_type="search_document",
|
||||
truncate=self.truncate,
|
||||
)
|
||||
return [list(map(float, e)) for e in embeddings.embeddings]
|
||||
return await self.aembed(texts, input_type="search_document")
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Call out to Cohere's embedding endpoint.
|
||||
@ -121,13 +131,7 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
embeddings = self.client.embed(
|
||||
model=self.model,
|
||||
texts=[text],
|
||||
input_type="search_query",
|
||||
truncate=self.truncate,
|
||||
).embeddings
|
||||
return [list(map(float, e)) for e in embeddings][0]
|
||||
return self.embed([text], input_type="search_query")[0]
|
||||
|
||||
async def aembed_query(self, text: str) -> List[float]:
|
||||
"""Async call out to Cohere's embedding endpoint.
|
||||
@ -138,10 +142,4 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
embeddings = await self.async_client.embed(
|
||||
model=self.model,
|
||||
texts=[text],
|
||||
input_type="search_query",
|
||||
truncate=self.truncate,
|
||||
)
|
||||
return [list(map(float, e)) for e in embeddings.embeddings][0]
|
||||
return (await self.aembed([text], input_type="search_query"))[0]
|
||||
|
Loading…
Reference in New Issue
Block a user