mirror of
https://github.com/hwchase17/langchain.git
synced 2026-01-23 13:19:22 +00:00
Compare commits
1 Commits
sr/fix-too
...
dev2049/em
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
863baba3f0 |
@@ -42,9 +42,9 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
|
||||
"""Output keys for Hyde's LLM chain."""
|
||||
return self.llm_chain.output_keys
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Call the base embeddings."""
|
||||
return self.base_embeddings.embed_documents(texts)
|
||||
return self.base_embeddings.embed_texts(texts)
|
||||
|
||||
def combine_embeddings(self, embeddings: List[List[float]]) -> List[float]:
|
||||
"""Combine embeddings into final embeddings."""
|
||||
@@ -55,7 +55,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
|
||||
var_name = self.llm_chain.input_keys[0]
|
||||
result = self.llm_chain.generate([{var_name: text}])
|
||||
documents = [generation.text for generation in result.generations[0]]
|
||||
embeddings = self.embed_documents(documents)
|
||||
embeddings = self.embed_texts(documents)
|
||||
return self.combine_embeddings(embeddings)
|
||||
|
||||
def _call(
|
||||
|
||||
@@ -55,9 +55,7 @@ def _get_embeddings_from_stateful_docs(
|
||||
if len(documents) and "embedded_doc" in documents[0].state:
|
||||
embedded_documents = [doc.state["embedded_doc"] for doc in documents]
|
||||
else:
|
||||
embedded_documents = embeddings.embed_documents(
|
||||
[d.page_content for d in documents]
|
||||
)
|
||||
embedded_documents = embeddings.embed_texts([d.page_content for d in documents])
|
||||
for doc, embedding in zip(documents, embedded_documents):
|
||||
doc.state["embedded_doc"] = embedding
|
||||
return embedded_documents
|
||||
|
||||
@@ -65,7 +65,7 @@ class AlephAlphaAsymmetricSemanticEmbedding(BaseModel, Embeddings):
|
||||
values["client"] = Client(token=aleph_alpha_api_key)
|
||||
return values
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Call out to Aleph Alpha's asymmetric Document endpoint.
|
||||
|
||||
Args:
|
||||
@@ -186,7 +186,7 @@ class AlephAlphaSymmetricSemanticEmbedding(AlephAlphaAsymmetricSemanticEmbedding
|
||||
|
||||
return query_response.embedding
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Call out to Aleph Alpha's Document endpoint.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -6,9 +6,13 @@ from typing import List
|
||||
class Embeddings(ABC):
|
||||
"""Interface for embedding models."""
|
||||
|
||||
@abstractmethod
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Embed search docs."""
|
||||
"""DEPRECATED. Kept for backwards compatibility."""
|
||||
return self.embed_texts(texts)
|
||||
|
||||
@abstractmethod
|
||||
def embed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Embed search texts."""
|
||||
|
||||
@abstractmethod
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
|
||||
@@ -54,7 +54,7 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
||||
)
|
||||
return values
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Call out to Cohere's embedding endpoint.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -12,7 +12,7 @@ class FakeEmbeddings(Embeddings, BaseModel):
|
||||
def _get_embedding(self) -> List[float]:
|
||||
return list(np.random.normal(size=self.size))
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
return [self._get_embedding() for _ in texts]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
|
||||
@@ -77,7 +77,7 @@ class GooglePalmEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
return values
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
return [self.embed_query(text) for text in texts]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
|
||||
@@ -60,7 +60,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Compute doc embeddings using a HuggingFace transformer model.
|
||||
|
||||
Args:
|
||||
@@ -135,7 +135,7 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Compute doc embeddings using a HuggingFace instruct model.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -77,7 +77,7 @@ class HuggingFaceHubEmbeddings(BaseModel, Embeddings):
|
||||
)
|
||||
return values
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Call out to HuggingFaceHub's embedding endpoint for embedding search docs.
|
||||
|
||||
Args:
|
||||
@@ -101,5 +101,5 @@ class HuggingFaceHubEmbeddings(BaseModel, Embeddings):
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
response = self.embed_documents([text])[0]
|
||||
response = self.embed_texts([text])[0]
|
||||
return response
|
||||
|
||||
@@ -71,7 +71,7 @@ class JinaEmbeddings(BaseModel, Embeddings):
|
||||
payload = dict(inputs=docs, metadata=self.request_headers, **kwargs)
|
||||
return self.client.post(on="/encode", **payload)
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Call out to Jina's embedding endpoint.
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
|
||||
@@ -99,7 +99,7 @@ class LlamaCppEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
return values
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Embed a list of documents using the Llama model.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -271,7 +271,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
headers=self.headers,
|
||||
)["data"][0]["embedding"]
|
||||
|
||||
def embed_documents(
|
||||
def embed_texts(
|
||||
self, texts: List[str], chunk_size: Optional[int] = 0
|
||||
) -> List[List[float]]:
|
||||
"""Call out to OpenAI's embedding endpoint for embedding search docs.
|
||||
|
||||
@@ -164,9 +164,7 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
return self.content_handler.transform_output(response["Body"])
|
||||
|
||||
def embed_documents(
|
||||
self, texts: List[str], chunk_size: int = 64
|
||||
) -> List[List[float]]:
|
||||
def embed_texts(self, texts: List[str], chunk_size: int = 64) -> List[List[float]]:
|
||||
"""Compute doc embeddings using a SageMaker Inference Endpoint.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -72,7 +72,7 @@ class SelfHostedEmbeddings(SelfHostedPipeline, Embeddings):
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Compute doc embeddings using a HuggingFace transformer model.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -140,7 +140,7 @@ class SelfHostedHuggingFaceInstructEmbeddings(SelfHostedHuggingFaceEmbeddings):
|
||||
load_fn_kwargs["device"] = load_fn_kwargs.get("device", 0)
|
||||
super().__init__(load_fn_kwargs=load_fn_kwargs, **kwargs)
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Compute doc embeddings using a HuggingFace instruct model.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -43,7 +43,7 @@ class TensorflowHubEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Compute doc embeddings using a TensorflowHub embedding model.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -49,7 +49,7 @@ def create_index(
|
||||
]
|
||||
|
||||
# create dense vectors
|
||||
dense_embeds = embeddings.embed_documents(context_batch)
|
||||
dense_embeds = embeddings.embed_texts(context_batch)
|
||||
# create sparse vectors
|
||||
sparse_embeds = sparse_encoder.encode_documents(context_batch)
|
||||
for s in sparse_embeds:
|
||||
|
||||
@@ -202,7 +202,7 @@ class AnalyticDB(VectorStore):
|
||||
if ids is None:
|
||||
ids = [str(uuid.uuid1()) for _ in texts]
|
||||
|
||||
embeddings = self.embedding_function.embed_documents(list(texts))
|
||||
embeddings = self.embedding_function.embed_texts(list(texts))
|
||||
|
||||
if not metadatas:
|
||||
metadatas = [{} for _ in texts]
|
||||
|
||||
@@ -351,7 +351,7 @@ class Annoy(VectorStore):
|
||||
embeddings = OpenAIEmbeddings()
|
||||
index = Annoy.from_texts(texts, embeddings)
|
||||
"""
|
||||
embeddings = embedding.embed_documents(texts)
|
||||
embeddings = embedding.embed_texts(texts)
|
||||
return cls.__from(
|
||||
texts, embeddings, embedding, metadatas, metric, trees, n_jobs, **kwargs
|
||||
)
|
||||
|
||||
@@ -119,7 +119,7 @@ class AtlasDB(VectorStore):
|
||||
|
||||
# Embedding upload case
|
||||
if self._embedding_function is not None:
|
||||
_embeddings = self._embedding_function.embed_documents(texts)
|
||||
_embeddings = self._embedding_function.embed_texts(texts)
|
||||
embeddings = np.stack(_embeddings)
|
||||
if metadatas is None:
|
||||
data = [
|
||||
@@ -194,7 +194,7 @@ class AtlasDB(VectorStore):
|
||||
"AtlasDB requires an embedding_function for text similarity search!"
|
||||
)
|
||||
|
||||
_embedding = self._embedding_function.embed_documents([query])[0]
|
||||
_embedding = self._embedding_function.embed_texts([query])[0]
|
||||
embedding = np.array(_embedding).reshape(1, -1)
|
||||
with self.project.wait_for_project_lock():
|
||||
neighbors, _ = self.project.projections[0].vector_search(
|
||||
|
||||
@@ -92,7 +92,7 @@ class Chroma(VectorStore):
|
||||
self._persist_directory = persist_directory
|
||||
self._collection = self._client.get_or_create_collection(
|
||||
name=collection_name,
|
||||
embedding_function=self._embedding_function.embed_documents
|
||||
embedding_function=self._embedding_function.embed_texts
|
||||
if self._embedding_function is not None
|
||||
else None,
|
||||
metadata=collection_metadata,
|
||||
@@ -156,7 +156,7 @@ class Chroma(VectorStore):
|
||||
ids = [str(uuid.uuid1()) for _ in texts]
|
||||
embeddings = None
|
||||
if self._embedding_function is not None:
|
||||
embeddings = self._embedding_function.embed_documents(list(texts))
|
||||
embeddings = self._embedding_function.embed_texts(list(texts))
|
||||
self._collection.add(
|
||||
metadatas=metadatas, embeddings=embeddings, documents=texts, ids=ids
|
||||
)
|
||||
|
||||
@@ -224,7 +224,7 @@ class DeepLake(VectorStore):
|
||||
embeds: Sequence[Optional[np.ndarray]] = []
|
||||
|
||||
if self._embedding_function is not None:
|
||||
embeddings = self._embedding_function.embed_documents(text_list)
|
||||
embeddings = self._embedding_function.embed_texts(text_list)
|
||||
embeds = [np.array(e, dtype=np.float32) for e in embeddings]
|
||||
else:
|
||||
embeds = [None] * len(text_list)
|
||||
|
||||
@@ -77,7 +77,7 @@ class DocArrayIndex(VectorStore, ABC):
|
||||
List of ids from adding the texts into the vectorstore.
|
||||
"""
|
||||
ids: List[str] = []
|
||||
embeddings = self.embedding.embed_documents(list(texts))
|
||||
embeddings = self.embedding.embed_texts(list(texts))
|
||||
for i, (t, e) in enumerate(zip(texts, embeddings)):
|
||||
m = metadatas[i] if metadatas else {}
|
||||
doc = self.doc_cls(text=t, embedding=e, metadata=m)
|
||||
|
||||
@@ -167,7 +167,7 @@ class ElasticVectorSearch(VectorStore, ABC):
|
||||
)
|
||||
requests = []
|
||||
ids = []
|
||||
embeddings = self.embedding.embed_documents(list(texts))
|
||||
embeddings = self.embedding.embed_texts(list(texts))
|
||||
dim = len(embeddings[0])
|
||||
mapping = _default_text_mapping(dim)
|
||||
|
||||
|
||||
@@ -417,7 +417,7 @@ class FAISS(VectorStore):
|
||||
embeddings = OpenAIEmbeddings()
|
||||
faiss = FAISS.from_texts(texts, embeddings)
|
||||
"""
|
||||
embeddings = embedding.embed_documents(texts)
|
||||
embeddings = embedding.embed_texts(texts)
|
||||
return cls.__from(
|
||||
texts,
|
||||
embeddings,
|
||||
|
||||
@@ -71,7 +71,7 @@ class LanceDB(VectorStore):
|
||||
# Embed texts and create documents
|
||||
docs = []
|
||||
ids = ids or [str(uuid.uuid4()) for _ in texts]
|
||||
embeddings = self._embedding.embed_documents(list(texts))
|
||||
embeddings = self._embedding.embed_texts(list(texts))
|
||||
for idx, text in enumerate(texts):
|
||||
embedding = embeddings[idx]
|
||||
metadata = metadatas[idx] if metadatas else {}
|
||||
|
||||
@@ -397,7 +397,7 @@ class Milvus(VectorStore):
|
||||
texts = list(texts)
|
||||
|
||||
try:
|
||||
embeddings = self.embedding_func.embed_documents(texts)
|
||||
embeddings = self.embedding_func.embed_texts(texts)
|
||||
except NotImplementedError:
|
||||
embeddings = [self.embedding_func.embed_query(x) for x in texts]
|
||||
|
||||
|
||||
@@ -327,7 +327,7 @@ class OpenSearchVectorSearch(VectorStore):
|
||||
text_field: Document field the text of the document is stored in. Defaults
|
||||
to "text".
|
||||
"""
|
||||
embeddings = self.embedding_function.embed_documents(list(texts))
|
||||
embeddings = self.embedding_function.embed_texts(list(texts))
|
||||
_validate_embeddings_and_bulk_size(len(embeddings), bulk_size)
|
||||
text_field = _get_kwargs_value(kwargs, "text_field", "text")
|
||||
dim = len(embeddings[0])
|
||||
@@ -560,7 +560,7 @@ class OpenSearchVectorSearch(VectorStore):
|
||||
"ef_construction",
|
||||
"m",
|
||||
]
|
||||
embeddings = embedding.embed_documents(texts)
|
||||
embeddings = embedding.embed_texts(texts)
|
||||
_validate_embeddings_and_bulk_size(len(embeddings), bulk_size)
|
||||
dim = len(embeddings[0])
|
||||
# Get the index name from either from kwargs or ENV Variable
|
||||
|
||||
@@ -212,7 +212,7 @@ class PGVector(VectorStore):
|
||||
if ids is None:
|
||||
ids = [str(uuid.uuid1()) for _ in texts]
|
||||
|
||||
embeddings = self.embedding_function.embed_documents(list(texts))
|
||||
embeddings = self.embedding_function.embed_texts(list(texts))
|
||||
|
||||
if not metadatas:
|
||||
metadatas = [{} for _ in texts]
|
||||
|
||||
@@ -229,7 +229,7 @@ class Pinecone(VectorStore):
|
||||
else:
|
||||
ids_batch = [str(uuid.uuid4()) for n in range(i, i_end)]
|
||||
# create embeddings
|
||||
embeds = embedding.embed_documents(lines_batch)
|
||||
embeds = embedding.embed_texts(lines_batch)
|
||||
# prep metadata and upsert batch
|
||||
if metadatas:
|
||||
metadata = metadatas[i:i_end]
|
||||
|
||||
@@ -139,7 +139,7 @@ class Qdrant(VectorStore):
|
||||
List of floats representing the texts embedding.
|
||||
"""
|
||||
if self.embeddings is not None:
|
||||
embeddings = self.embeddings.embed_documents(list(texts))
|
||||
embeddings = self.embeddings.embed_texts(list(texts))
|
||||
if hasattr(embeddings, "tolist"):
|
||||
embeddings = embeddings.tolist()
|
||||
elif self._embeddings_function is not None:
|
||||
@@ -388,7 +388,7 @@ class Qdrant(VectorStore):
|
||||
from qdrant_client.http import models as rest
|
||||
|
||||
# Just do a single quick embedding to get vector size
|
||||
partial_embeddings = embedding.embed_documents(texts[:1])
|
||||
partial_embeddings = embedding.embed_texts(texts[:1])
|
||||
vector_size = len(partial_embeddings[0])
|
||||
|
||||
collection_name = collection_name or uuid.uuid4().hex
|
||||
@@ -418,7 +418,7 @@ class Qdrant(VectorStore):
|
||||
)
|
||||
|
||||
# Now generate the embeddings for all the texts
|
||||
embeddings = embedding.embed_documents(texts)
|
||||
embeddings = embedding.embed_texts(texts)
|
||||
|
||||
client.upsert(
|
||||
collection_name=collection_name,
|
||||
|
||||
@@ -408,7 +408,7 @@ class Redis(VectorStore):
|
||||
)
|
||||
|
||||
# Create embeddings over documents
|
||||
embeddings = embedding.embed_documents(texts)
|
||||
embeddings = embedding.embed_texts(texts)
|
||||
|
||||
# Create the search index
|
||||
instance._create_index(dim=len(embeddings[0]), distance_metric=distance_metric)
|
||||
|
||||
@@ -74,7 +74,7 @@ class SupabaseVectorStore(VectorStore):
|
||||
) -> List[str]:
|
||||
docs = self._texts_to_documents(texts, metadatas)
|
||||
|
||||
vectors = self._embedding.embed_documents(list(texts))
|
||||
vectors = self._embedding.embed_texts(list(texts))
|
||||
return self.add_vectors(vectors, docs)
|
||||
|
||||
@classmethod
|
||||
@@ -96,7 +96,7 @@ class SupabaseVectorStore(VectorStore):
|
||||
if not table_name:
|
||||
raise ValueError("Supabase document table_name is required.")
|
||||
|
||||
embeddings = embedding.embed_documents(texts)
|
||||
embeddings = embedding.embed_texts(texts)
|
||||
docs = cls._texts_to_documents(texts, metadatas)
|
||||
_ids = cls._add_vectors(client, table_name, embeddings, docs)
|
||||
|
||||
@@ -115,7 +115,7 @@ class SupabaseVectorStore(VectorStore):
|
||||
def similarity_search(
|
||||
self, query: str, k: int = 4, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
vectors = self._embedding.embed_documents([query])
|
||||
vectors = self._embedding.embed_texts([query])
|
||||
return self.similarity_search_by_vector(vectors[0], k)
|
||||
|
||||
def similarity_search_by_vector(
|
||||
@@ -130,7 +130,7 @@ class SupabaseVectorStore(VectorStore):
|
||||
def similarity_search_with_relevance_scores(
|
||||
self, query: str, k: int = 4, **kwargs: Any
|
||||
) -> List[Tuple[Document, float]]:
|
||||
vectors = self._embedding.embed_documents([query])
|
||||
vectors = self._embedding.embed_texts([query])
|
||||
return self.similarity_search_by_vector_with_relevance_scores(vectors[0], k)
|
||||
|
||||
def similarity_search_by_vector_with_relevance_scores(
|
||||
@@ -328,7 +328,7 @@ class SupabaseVectorStore(VectorStore):
|
||||
END;
|
||||
$$;```
|
||||
"""
|
||||
embedding = self._embedding.embed_documents([query])
|
||||
embedding = self._embedding.embed_texts([query])
|
||||
docs = self.max_marginal_relevance_search_by_vector(
|
||||
embedding[0], k, fetch_k, lambda_mult=lambda_mult
|
||||
)
|
||||
|
||||
@@ -82,7 +82,7 @@ class Tair(VectorStore):
|
||||
keys = kwargs.get("keys", None)
|
||||
# Write data to tair
|
||||
pipeline = self.client.pipeline(transaction=False)
|
||||
embeddings = self.embedding_function.embed_documents(list(texts))
|
||||
embeddings = self.embedding_function.embed_texts(list(texts))
|
||||
for i, text in enumerate(texts):
|
||||
# Use provided key otherwise use default key
|
||||
key = keys[i] if keys else _uuid_key()
|
||||
@@ -191,7 +191,7 @@ class Tair(VectorStore):
|
||||
raise ValueError(f"tair failed to connect: {e}")
|
||||
|
||||
# Create embeddings for documents
|
||||
embeddings = embedding.embed_documents(texts)
|
||||
embeddings = embedding.embed_texts(texts)
|
||||
|
||||
tair_vector_store.create_index_if_not_exist(
|
||||
len(embeddings[0]),
|
||||
|
||||
@@ -142,7 +142,7 @@ class Weaviate(VectorStore):
|
||||
)
|
||||
|
||||
if self._embedding is not None:
|
||||
vector = self._embedding.embed_documents([text])[0]
|
||||
vector = self._embedding.embed_texts([text])[0]
|
||||
else:
|
||||
vector = None
|
||||
batch.add_data_object(
|
||||
@@ -391,7 +391,7 @@ class Weaviate(VectorStore):
|
||||
from weaviate.util import get_valid_uuid
|
||||
|
||||
index_name = kwargs.get("index_name", f"LangChain_{uuid4().hex}")
|
||||
embeddings = embedding.embed_documents(texts) if embedding else None
|
||||
embeddings = embedding.embed_texts(texts) if embedding else None
|
||||
text_key = "text"
|
||||
schema = _default_schema(index_name)
|
||||
attributes = list(metadatas[0].keys()) if metadatas else None
|
||||
|
||||
@@ -6,7 +6,7 @@ def test_cohere_embedding_documents() -> None:
|
||||
"""Test cohere embeddings."""
|
||||
documents = ["foo bar"]
|
||||
embedding = CohereEmbeddings()
|
||||
output = embedding.embed_documents(documents)
|
||||
output = embedding.embed_texts(documents)
|
||||
assert len(output) == 1
|
||||
assert len(output[0]) == 2048
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ def test_google_palm_embedding_documents() -> None:
|
||||
"""Test Google PaLM embeddings."""
|
||||
documents = ["foo bar"]
|
||||
embedding = GooglePalmEmbeddings()
|
||||
output = embedding.embed_documents(documents)
|
||||
output = embedding.embed_texts(documents)
|
||||
assert len(output) == 1
|
||||
assert len(output[0]) == 768
|
||||
|
||||
@@ -19,7 +19,7 @@ def test_google_palm_embedding_documents_multiple() -> None:
|
||||
"""Test Google PaLM embeddings."""
|
||||
documents = ["foo bar", "bar foo", "foo"]
|
||||
embedding = GooglePalmEmbeddings()
|
||||
output = embedding.embed_documents(documents)
|
||||
output = embedding.embed_texts(documents)
|
||||
assert len(output) == 3
|
||||
assert len(output[0]) == 768
|
||||
assert len(output[1]) == 768
|
||||
|
||||
@@ -10,7 +10,7 @@ def test_huggingface_embedding_documents() -> None:
|
||||
"""Test huggingface embeddings."""
|
||||
documents = ["foo bar"]
|
||||
embedding = HuggingFaceEmbeddings()
|
||||
output = embedding.embed_documents(documents)
|
||||
output = embedding.embed_texts(documents)
|
||||
assert len(output) == 1
|
||||
assert len(output[0]) == 768
|
||||
|
||||
@@ -27,7 +27,7 @@ def test_huggingface_instructor_embedding_documents() -> None:
|
||||
"""Test huggingface embeddings."""
|
||||
documents = ["foo bar"]
|
||||
embedding = HuggingFaceInstructEmbeddings()
|
||||
output = embedding.embed_documents(documents)
|
||||
output = embedding.embed_texts(documents)
|
||||
assert len(output) == 1
|
||||
assert len(output[0]) == 768
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ def test_huggingfacehub_embedding_documents() -> None:
|
||||
"""Test huggingfacehub embeddings."""
|
||||
documents = ["foo bar"]
|
||||
embedding = HuggingFaceHubEmbeddings()
|
||||
output = embedding.embed_documents(documents)
|
||||
output = embedding.embed_texts(documents)
|
||||
assert len(output) == 1
|
||||
assert len(output[0]) == 768
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ def test_jina_embedding_documents() -> None:
|
||||
"""Test jina embeddings for documents."""
|
||||
documents = ["foo bar", "bar foo"]
|
||||
embedding = JinaEmbeddings()
|
||||
output = embedding.embed_documents(documents)
|
||||
output = embedding.embed_texts(documents)
|
||||
assert len(output) == 2
|
||||
assert len(output[0]) == 512
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ def test_llamacpp_embedding_documents() -> None:
|
||||
documents = ["foo bar"]
|
||||
model_path = get_model()
|
||||
embedding = LlamaCppEmbeddings(model_path=model_path)
|
||||
output = embedding.embed_documents(documents)
|
||||
output = embedding.embed_texts(documents)
|
||||
assert len(output) == 1
|
||||
assert len(output[0]) == 512
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ def test_openai_embedding_documents() -> None:
|
||||
"""Test openai embeddings."""
|
||||
documents = ["foo bar"]
|
||||
embedding = OpenAIEmbeddings()
|
||||
output = embedding.embed_documents(documents)
|
||||
output = embedding.embed_texts(documents)
|
||||
assert len(output) == 1
|
||||
assert len(output[0]) == 1536
|
||||
|
||||
@@ -19,7 +19,7 @@ def test_openai_embedding_documents_multiple() -> None:
|
||||
documents = ["foo bar", "bar foo", "foo"]
|
||||
embedding = OpenAIEmbeddings(chunk_size=2)
|
||||
embedding.embedding_ctx_length = 8191
|
||||
output = embedding.embed_documents(documents)
|
||||
output = embedding.embed_texts(documents)
|
||||
assert len(output) == 3
|
||||
assert len(output[0]) == 1536
|
||||
assert len(output[1]) == 1536
|
||||
@@ -38,7 +38,7 @@ def test_openai_embedding_with_empty_string() -> None:
|
||||
"""Test openai embeddings with empty string."""
|
||||
document = ["", "abc"]
|
||||
embedding = OpenAIEmbeddings()
|
||||
output = embedding.embed_documents(document)
|
||||
output = embedding.embed_texts(document)
|
||||
assert len(output) == 2
|
||||
assert len(output[0]) == 1536
|
||||
expected_output = openai.Embedding.create(input="", model="text-embedding-ada-002")[
|
||||
|
||||
@@ -24,7 +24,7 @@ def test_self_hosted_huggingface_embedding_documents() -> None:
|
||||
documents = ["foo bar"]
|
||||
gpu = get_remote_instance()
|
||||
embedding = SelfHostedHuggingFaceEmbeddings(hardware=gpu)
|
||||
output = embedding.embed_documents(documents)
|
||||
output = embedding.embed_texts(documents)
|
||||
assert len(output) == 1
|
||||
assert len(output[0]) == 768
|
||||
|
||||
@@ -43,7 +43,7 @@ def test_self_hosted_huggingface_instructor_embedding_documents() -> None:
|
||||
documents = ["foo bar"]
|
||||
gpu = get_remote_instance()
|
||||
embedding = SelfHostedHuggingFaceInstructEmbeddings(hardware=gpu)
|
||||
output = embedding.embed_documents(documents)
|
||||
output = embedding.embed_texts(documents)
|
||||
assert len(output) == 1
|
||||
assert len(output[0]) == 768
|
||||
|
||||
@@ -80,7 +80,7 @@ def test_self_hosted_embedding_documents() -> None:
|
||||
embedding = SelfHostedEmbeddings(
|
||||
model_load_fn=get_pipeline, hardware=gpu, inference_fn=inference_fn
|
||||
)
|
||||
output = embedding.embed_documents(documents)
|
||||
output = embedding.embed_texts(documents)
|
||||
assert len(output) == 2
|
||||
assert len(output[0]) == 50265
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ def test_sentence_transformer_embedding_documents() -> None:
|
||||
"""Test sentence_transformer embeddings."""
|
||||
embedding = SentenceTransformerEmbeddings()
|
||||
documents = ["foo bar"]
|
||||
output = embedding.embed_documents(documents)
|
||||
output = embedding.embed_texts(documents)
|
||||
assert len(output) == 1
|
||||
assert len(output[0]) == 384
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ def test_tensorflowhub_embedding_documents() -> None:
|
||||
"""Test tensorflowhub embeddings."""
|
||||
documents = ["foo bar"]
|
||||
embedding = TensorflowHubEmbeddings()
|
||||
output = embedding.embed_documents(documents)
|
||||
output = embedding.embed_texts(documents)
|
||||
assert len(output) == 1
|
||||
assert len(output[0]) == 512
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ fake_texts = ["foo", "bar", "baz"]
|
||||
class FakeEmbeddings(Embeddings):
|
||||
"""Fake embeddings functionality for testing."""
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Return simple embeddings.
|
||||
Embeddings encode each text as its index."""
|
||||
return [[float(1.0)] * 9 + [float(i)] for i in range(len(texts))]
|
||||
|
||||
@@ -24,7 +24,7 @@ ADA_TOKEN_COUNT = 1536
|
||||
class FakeEmbeddingsWithAdaDimension(FakeEmbeddings):
|
||||
"""Fake embeddings functionality for testing."""
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Return simple embeddings."""
|
||||
return [
|
||||
[float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(i)] for i in range(len(texts))
|
||||
|
||||
@@ -143,7 +143,7 @@ def test_similarity_search_by_vector(
|
||||
deeplake_datastore: DeepLake, distance_metric: str
|
||||
) -> None:
|
||||
"""Test similarity search by vector."""
|
||||
embeddings = FakeEmbeddings().embed_documents(["foo", "bar", "baz"])
|
||||
embeddings = FakeEmbeddings().embed_texts(["foo", "bar", "baz"])
|
||||
output = deeplake_datastore.similarity_search_by_vector(
|
||||
embeddings[1], k=1, distance_metric=distance_metric
|
||||
)
|
||||
@@ -185,7 +185,7 @@ def test_max_marginal_relevance_search(deeplake_datastore: DeepLake) -> None:
|
||||
|
||||
assert output == [Document(page_content="foo", metadata={"page": "0"})]
|
||||
|
||||
embeddings = FakeEmbeddings().embed_documents(["foo", "bar", "baz"])
|
||||
embeddings = FakeEmbeddings().embed_texts(["foo", "bar", "baz"])
|
||||
output = deeplake_datastore.max_marginal_relevance_search_by_vector(
|
||||
embeddings[0], k=1, fetch_k=2
|
||||
)
|
||||
|
||||
@@ -8,7 +8,7 @@ def test_lancedb() -> None:
|
||||
embeddings = FakeEmbeddings()
|
||||
db = lancedb.connect("/tmp/lancedb")
|
||||
texts = ["text 1", "text 2", "item 3"]
|
||||
vectors = embeddings.embed_documents(texts)
|
||||
vectors = embeddings.embed_texts(texts)
|
||||
table = db.create_table(
|
||||
"my_table",
|
||||
data=[
|
||||
@@ -27,7 +27,7 @@ def test_lancedb_add_texts() -> None:
|
||||
embeddings = FakeEmbeddings()
|
||||
db = lancedb.connect("/tmp/lancedb")
|
||||
texts = ["text 1"]
|
||||
vectors = embeddings.embed_documents(texts)
|
||||
vectors = embeddings.embed_texts(texts)
|
||||
table = db.create_table(
|
||||
"my_table",
|
||||
data=[
|
||||
|
||||
@@ -24,7 +24,7 @@ ADA_TOKEN_COUNT = 1536
|
||||
class FakeEmbeddingsWithAdaDimension(FakeEmbeddings):
|
||||
"""Fake embeddings functionality for testing."""
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Return simple embeddings."""
|
||||
return [
|
||||
[float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(i)] for i in range(len(texts))
|
||||
|
||||
@@ -17,7 +17,7 @@ from langchain.schema import Generation, LLMResult
|
||||
class FakeEmbeddings(Embeddings):
|
||||
"""Fake embedding class for tests."""
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Return random floats."""
|
||||
return [list(np.random.uniform(0, 1, 10)) for _ in range(10)]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user