mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-10 15:06:18 +00:00
Add input_type override (#14068)
Add an option to override input_type for Cohere's v3 embedding models.
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
aaabc1574f
commit
0f02081392
File diff suppressed because one or more lines are too long
@ -1,8 +1,7 @@
|
|||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
|
from langchain.schema.embeddings import Embeddings
|
||||||
|
|
||||||
from langchain.utils import get_from_dict_or_env
|
from langchain.utils import get_from_dict_or_env
|
||||||
|
|
||||||
|
|
||||||
@ -18,7 +17,8 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
|||||||
|
|
||||||
from langchain.embeddings import CohereEmbeddings
|
from langchain.embeddings import CohereEmbeddings
|
||||||
cohere = CohereEmbeddings(
|
cohere = CohereEmbeddings(
|
||||||
model="embed-english-light-v3.0", cohere_api_key="my-api-key"
|
model="embed-english-light-v3.0",
|
||||||
|
cohere_api_key="my-api-key"
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -78,8 +78,30 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
|||||||
)
|
)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
def embed(
|
||||||
|
self, texts: List[str], *, input_type: Optional[str] = None
|
||||||
|
) -> List[List[float]]:
|
||||||
|
embeddings = self.client.embed(
|
||||||
|
model=self.model,
|
||||||
|
texts=texts,
|
||||||
|
input_type=input_type,
|
||||||
|
truncate=self.truncate,
|
||||||
|
).embeddings
|
||||||
|
return [list(map(float, e)) for e in embeddings]
|
||||||
|
|
||||||
|
async def aembed(
|
||||||
|
self, texts: List[str], *, input_type: Optional[str] = None
|
||||||
|
) -> List[List[float]]:
|
||||||
|
embeddings = await self.async_client.embed(
|
||||||
|
model=self.model,
|
||||||
|
texts=texts,
|
||||||
|
input_type=input_type,
|
||||||
|
truncate=self.truncate,
|
||||||
|
).embeddings
|
||||||
|
return [list(map(float, e)) for e in embeddings]
|
||||||
|
|
||||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||||
"""Call out to Cohere's embedding endpoint.
|
"""Embed a list of document texts.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
texts: The list of texts to embed.
|
texts: The list of texts to embed.
|
||||||
@ -87,13 +109,7 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
|||||||
Returns:
|
Returns:
|
||||||
List of embeddings, one for each text.
|
List of embeddings, one for each text.
|
||||||
"""
|
"""
|
||||||
embeddings = self.client.embed(
|
return self.embed(texts, input_type="search_document")
|
||||||
model=self.model,
|
|
||||||
texts=texts,
|
|
||||||
input_type="search_document",
|
|
||||||
truncate=self.truncate,
|
|
||||||
).embeddings
|
|
||||||
return [list(map(float, e)) for e in embeddings]
|
|
||||||
|
|
||||||
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||||
"""Async call out to Cohere's embedding endpoint.
|
"""Async call out to Cohere's embedding endpoint.
|
||||||
@ -104,13 +120,7 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
|||||||
Returns:
|
Returns:
|
||||||
List of embeddings, one for each text.
|
List of embeddings, one for each text.
|
||||||
"""
|
"""
|
||||||
embeddings = await self.async_client.embed(
|
return await self.aembed(texts, input_type="search_document")
|
||||||
model=self.model,
|
|
||||||
texts=texts,
|
|
||||||
input_type="search_document",
|
|
||||||
truncate=self.truncate,
|
|
||||||
)
|
|
||||||
return [list(map(float, e)) for e in embeddings.embeddings]
|
|
||||||
|
|
||||||
def embed_query(self, text: str) -> List[float]:
|
def embed_query(self, text: str) -> List[float]:
|
||||||
"""Call out to Cohere's embedding endpoint.
|
"""Call out to Cohere's embedding endpoint.
|
||||||
@ -121,13 +131,7 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
|||||||
Returns:
|
Returns:
|
||||||
Embeddings for the text.
|
Embeddings for the text.
|
||||||
"""
|
"""
|
||||||
embeddings = self.client.embed(
|
return self.embed([text], input_type="search_query")[0]
|
||||||
model=self.model,
|
|
||||||
texts=[text],
|
|
||||||
input_type="search_query",
|
|
||||||
truncate=self.truncate,
|
|
||||||
).embeddings
|
|
||||||
return [list(map(float, e)) for e in embeddings][0]
|
|
||||||
|
|
||||||
async def aembed_query(self, text: str) -> List[float]:
|
async def aembed_query(self, text: str) -> List[float]:
|
||||||
"""Async call out to Cohere's embedding endpoint.
|
"""Async call out to Cohere's embedding endpoint.
|
||||||
@ -138,10 +142,4 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
|||||||
Returns:
|
Returns:
|
||||||
Embeddings for the text.
|
Embeddings for the text.
|
||||||
"""
|
"""
|
||||||
embeddings = await self.async_client.embed(
|
return (await self.aembed([text], input_type="search_query"))[0]
|
||||||
model=self.model,
|
|
||||||
texts=[text],
|
|
||||||
input_type="search_query",
|
|
||||||
truncate=self.truncate,
|
|
||||||
)
|
|
||||||
return [list(map(float, e)) for e in embeddings.embeddings][0]
|
|
||||||
|
Loading…
Reference in New Issue
Block a user