Compare commits

...

1 Commits

Author SHA1 Message Date
Bagatur
0401eece1a rfc 2023-11-07 08:16:08 -08:00

View File

@@ -39,6 +39,10 @@ class CohereEmbeddings(BaseModel, Embeddings):
"""Timeout in seconds for the Cohere API request."""
user_agent: str = "langchain"
"""Identifier for the application making the request."""
embed_documents_input_type: str = "search_document"
""""""
embded_query_input_type: str = "search_query"
""""""
class Config:
"""Configuration for this pydantic object."""
@@ -89,7 +93,7 @@ class CohereEmbeddings(BaseModel, Embeddings):
embeddings = self.client.embed(
model=self.model,
texts=texts,
input_type="search_document",
input_type=self.embed_documents_input_type,
truncate=self.truncate,
).embeddings
return [list(map(float, e)) for e in embeddings]
@@ -106,7 +110,7 @@ class CohereEmbeddings(BaseModel, Embeddings):
embeddings = await self.async_client.embed(
model=self.model,
texts=texts,
input_type="search_document",
input_type=self.embed_documents_input_type,
truncate=self.truncate,
)
return [list(map(float, e)) for e in embeddings.embeddings]
@@ -123,7 +127,7 @@ class CohereEmbeddings(BaseModel, Embeddings):
embeddings = self.client.embed(
model=self.model,
texts=[text],
input_type="search_query",
input_type=self.embed_query_input_type,
truncate=self.truncate,
).embeddings
return [list(map(float, e)) for e in embeddings][0]
@@ -140,7 +144,7 @@ class CohereEmbeddings(BaseModel, Embeddings):
embeddings = await self.async_client.embed(
model=self.model,
texts=[text],
input_type="search_query",
input_type=self.embed_query_input_type,
truncate=self.truncate,
)
return [list(map(float, e)) for e in embeddings.embeddings][0]