Compare commits

...

1 Commit

Author SHA1 Message Date
Dev 2049
863baba3f0 refac 2023-05-22 13:57:17 -07:00
51 changed files with 80 additions and 80 deletions

View File

@@ -42,9 +42,9 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
"""Output keys for Hyde's LLM chain."""
return self.llm_chain.output_keys
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_texts(self, texts: List[str]) -> List[List[float]]:
"""Call the base embeddings."""
return self.base_embeddings.embed_documents(texts)
return self.base_embeddings.embed_texts(texts)
def combine_embeddings(self, embeddings: List[List[float]]) -> List[float]:
"""Combine embeddings into final embeddings."""
@@ -55,7 +55,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
var_name = self.llm_chain.input_keys[0]
result = self.llm_chain.generate([{var_name: text}])
documents = [generation.text for generation in result.generations[0]]
embeddings = self.embed_documents(documents)
embeddings = self.embed_texts(documents)
return self.combine_embeddings(embeddings)
def _call(

View File

@@ -55,9 +55,7 @@ def _get_embeddings_from_stateful_docs(
if len(documents) and "embedded_doc" in documents[0].state:
embedded_documents = [doc.state["embedded_doc"] for doc in documents]
else:
embedded_documents = embeddings.embed_documents(
[d.page_content for d in documents]
)
embedded_documents = embeddings.embed_texts([d.page_content for d in documents])
for doc, embedding in zip(documents, embedded_documents):
doc.state["embedded_doc"] = embedding
return embedded_documents

View File

@@ -65,7 +65,7 @@ class AlephAlphaAsymmetricSemanticEmbedding(BaseModel, Embeddings):
values["client"] = Client(token=aleph_alpha_api_key)
return values
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_texts(self, texts: List[str]) -> List[List[float]]:
"""Call out to Aleph Alpha's asymmetric Document endpoint.
Args:
@@ -186,7 +186,7 @@ class AlephAlphaSymmetricSemanticEmbedding(AlephAlphaAsymmetricSemanticEmbedding
return query_response.embedding
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_texts(self, texts: List[str]) -> List[List[float]]:
"""Call out to Aleph Alpha's Document endpoint.
Args:

View File

@@ -6,9 +6,13 @@ from typing import List
class Embeddings(ABC):
"""Interface for embedding models."""
@abstractmethod
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed search docs."""
"""DEPRECATED. Kept for backwards compatibility."""
return self.embed_texts(texts)
@abstractmethod
def embed_texts(self, texts: List[str]) -> List[List[float]]:
"""Embed search texts."""
@abstractmethod
def embed_query(self, text: str) -> List[float]:

View File

@@ -54,7 +54,7 @@ class CohereEmbeddings(BaseModel, Embeddings):
)
return values
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_texts(self, texts: List[str]) -> List[List[float]]:
"""Call out to Cohere's embedding endpoint.
Args:

View File

@@ -12,7 +12,7 @@ class FakeEmbeddings(Embeddings, BaseModel):
def _get_embedding(self) -> List[float]:
return list(np.random.normal(size=self.size))
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_texts(self, texts: List[str]) -> List[List[float]]:
return [self._get_embedding() for _ in texts]
def embed_query(self, text: str) -> List[float]:

View File

@@ -77,7 +77,7 @@ class GooglePalmEmbeddings(BaseModel, Embeddings):
return values
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_texts(self, texts: List[str]) -> List[List[float]]:
return [self.embed_query(text) for text in texts]
def embed_query(self, text: str) -> List[float]:

View File

@@ -60,7 +60,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
extra = Extra.forbid
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_texts(self, texts: List[str]) -> List[List[float]]:
"""Compute doc embeddings using a HuggingFace transformer model.
Args:
@@ -135,7 +135,7 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
extra = Extra.forbid
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_texts(self, texts: List[str]) -> List[List[float]]:
"""Compute doc embeddings using a HuggingFace instruct model.
Args:

View File

@@ -77,7 +77,7 @@ class HuggingFaceHubEmbeddings(BaseModel, Embeddings):
)
return values
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_texts(self, texts: List[str]) -> List[List[float]]:
"""Call out to HuggingFaceHub's embedding endpoint for embedding search docs.
Args:
@@ -101,5 +101,5 @@ class HuggingFaceHubEmbeddings(BaseModel, Embeddings):
Returns:
Embeddings for the text.
"""
response = self.embed_documents([text])[0]
response = self.embed_texts([text])[0]
return response

View File

@@ -71,7 +71,7 @@ class JinaEmbeddings(BaseModel, Embeddings):
payload = dict(inputs=docs, metadata=self.request_headers, **kwargs)
return self.client.post(on="/encode", **payload)
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_texts(self, texts: List[str]) -> List[List[float]]:
"""Call out to Jina's embedding endpoint.
Args:
texts: The list of texts to embed.

View File

@@ -99,7 +99,7 @@ class LlamaCppEmbeddings(BaseModel, Embeddings):
return values
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_texts(self, texts: List[str]) -> List[List[float]]:
"""Embed a list of documents using the Llama model.
Args:

View File

@@ -271,7 +271,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
headers=self.headers,
)["data"][0]["embedding"]
def embed_documents(
def embed_texts(
self, texts: List[str], chunk_size: Optional[int] = 0
) -> List[List[float]]:
"""Call out to OpenAI's embedding endpoint for embedding search docs.

View File

@@ -164,9 +164,7 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
return self.content_handler.transform_output(response["Body"])
def embed_documents(
self, texts: List[str], chunk_size: int = 64
) -> List[List[float]]:
def embed_texts(self, texts: List[str], chunk_size: int = 64) -> List[List[float]]:
"""Compute doc embeddings using a SageMaker Inference Endpoint.
Args:

View File

@@ -72,7 +72,7 @@ class SelfHostedEmbeddings(SelfHostedPipeline, Embeddings):
extra = Extra.forbid
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_texts(self, texts: List[str]) -> List[List[float]]:
"""Compute doc embeddings using a HuggingFace transformer model.
Args:

View File

@@ -140,7 +140,7 @@ class SelfHostedHuggingFaceInstructEmbeddings(SelfHostedHuggingFaceEmbeddings):
load_fn_kwargs["device"] = load_fn_kwargs.get("device", 0)
super().__init__(load_fn_kwargs=load_fn_kwargs, **kwargs)
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_texts(self, texts: List[str]) -> List[List[float]]:
"""Compute doc embeddings using a HuggingFace instruct model.
Args:

View File

@@ -43,7 +43,7 @@ class TensorflowHubEmbeddings(BaseModel, Embeddings):
extra = Extra.forbid
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_texts(self, texts: List[str]) -> List[List[float]]:
"""Compute doc embeddings using a TensorflowHub embedding model.
Args:

View File

@@ -49,7 +49,7 @@ def create_index(
]
# create dense vectors
dense_embeds = embeddings.embed_documents(context_batch)
dense_embeds = embeddings.embed_texts(context_batch)
# create sparse vectors
sparse_embeds = sparse_encoder.encode_documents(context_batch)
for s in sparse_embeds:

View File

@@ -202,7 +202,7 @@ class AnalyticDB(VectorStore):
if ids is None:
ids = [str(uuid.uuid1()) for _ in texts]
embeddings = self.embedding_function.embed_documents(list(texts))
embeddings = self.embedding_function.embed_texts(list(texts))
if not metadatas:
metadatas = [{} for _ in texts]

View File

@@ -351,7 +351,7 @@ class Annoy(VectorStore):
embeddings = OpenAIEmbeddings()
index = Annoy.from_texts(texts, embeddings)
"""
embeddings = embedding.embed_documents(texts)
embeddings = embedding.embed_texts(texts)
return cls.__from(
texts, embeddings, embedding, metadatas, metric, trees, n_jobs, **kwargs
)

View File

@@ -119,7 +119,7 @@ class AtlasDB(VectorStore):
# Embedding upload case
if self._embedding_function is not None:
_embeddings = self._embedding_function.embed_documents(texts)
_embeddings = self._embedding_function.embed_texts(texts)
embeddings = np.stack(_embeddings)
if metadatas is None:
data = [
@@ -194,7 +194,7 @@ class AtlasDB(VectorStore):
"AtlasDB requires an embedding_function for text similarity search!"
)
_embedding = self._embedding_function.embed_documents([query])[0]
_embedding = self._embedding_function.embed_texts([query])[0]
embedding = np.array(_embedding).reshape(1, -1)
with self.project.wait_for_project_lock():
neighbors, _ = self.project.projections[0].vector_search(

View File

@@ -92,7 +92,7 @@ class Chroma(VectorStore):
self._persist_directory = persist_directory
self._collection = self._client.get_or_create_collection(
name=collection_name,
embedding_function=self._embedding_function.embed_documents
embedding_function=self._embedding_function.embed_texts
if self._embedding_function is not None
else None,
metadata=collection_metadata,
@@ -156,7 +156,7 @@ class Chroma(VectorStore):
ids = [str(uuid.uuid1()) for _ in texts]
embeddings = None
if self._embedding_function is not None:
embeddings = self._embedding_function.embed_documents(list(texts))
embeddings = self._embedding_function.embed_texts(list(texts))
self._collection.add(
metadatas=metadatas, embeddings=embeddings, documents=texts, ids=ids
)

View File

@@ -224,7 +224,7 @@ class DeepLake(VectorStore):
embeds: Sequence[Optional[np.ndarray]] = []
if self._embedding_function is not None:
embeddings = self._embedding_function.embed_documents(text_list)
embeddings = self._embedding_function.embed_texts(text_list)
embeds = [np.array(e, dtype=np.float32) for e in embeddings]
else:
embeds = [None] * len(text_list)

View File

@@ -77,7 +77,7 @@ class DocArrayIndex(VectorStore, ABC):
List of ids from adding the texts into the vectorstore.
"""
ids: List[str] = []
embeddings = self.embedding.embed_documents(list(texts))
embeddings = self.embedding.embed_texts(list(texts))
for i, (t, e) in enumerate(zip(texts, embeddings)):
m = metadatas[i] if metadatas else {}
doc = self.doc_cls(text=t, embedding=e, metadata=m)

View File

@@ -167,7 +167,7 @@ class ElasticVectorSearch(VectorStore, ABC):
)
requests = []
ids = []
embeddings = self.embedding.embed_documents(list(texts))
embeddings = self.embedding.embed_texts(list(texts))
dim = len(embeddings[0])
mapping = _default_text_mapping(dim)

View File

@@ -417,7 +417,7 @@ class FAISS(VectorStore):
embeddings = OpenAIEmbeddings()
faiss = FAISS.from_texts(texts, embeddings)
"""
embeddings = embedding.embed_documents(texts)
embeddings = embedding.embed_texts(texts)
return cls.__from(
texts,
embeddings,

View File

@@ -71,7 +71,7 @@ class LanceDB(VectorStore):
# Embed texts and create documents
docs = []
ids = ids or [str(uuid.uuid4()) for _ in texts]
embeddings = self._embedding.embed_documents(list(texts))
embeddings = self._embedding.embed_texts(list(texts))
for idx, text in enumerate(texts):
embedding = embeddings[idx]
metadata = metadatas[idx] if metadatas else {}

View File

@@ -397,7 +397,7 @@ class Milvus(VectorStore):
texts = list(texts)
try:
embeddings = self.embedding_func.embed_documents(texts)
embeddings = self.embedding_func.embed_texts(texts)
except NotImplementedError:
embeddings = [self.embedding_func.embed_query(x) for x in texts]

View File

@@ -327,7 +327,7 @@ class OpenSearchVectorSearch(VectorStore):
text_field: Document field the text of the document is stored in. Defaults
to "text".
"""
embeddings = self.embedding_function.embed_documents(list(texts))
embeddings = self.embedding_function.embed_texts(list(texts))
_validate_embeddings_and_bulk_size(len(embeddings), bulk_size)
text_field = _get_kwargs_value(kwargs, "text_field", "text")
dim = len(embeddings[0])
@@ -560,7 +560,7 @@ class OpenSearchVectorSearch(VectorStore):
"ef_construction",
"m",
]
embeddings = embedding.embed_documents(texts)
embeddings = embedding.embed_texts(texts)
_validate_embeddings_and_bulk_size(len(embeddings), bulk_size)
dim = len(embeddings[0])
# Get the index name from either from kwargs or ENV Variable

View File

@@ -212,7 +212,7 @@ class PGVector(VectorStore):
if ids is None:
ids = [str(uuid.uuid1()) for _ in texts]
embeddings = self.embedding_function.embed_documents(list(texts))
embeddings = self.embedding_function.embed_texts(list(texts))
if not metadatas:
metadatas = [{} for _ in texts]

View File

@@ -229,7 +229,7 @@ class Pinecone(VectorStore):
else:
ids_batch = [str(uuid.uuid4()) for n in range(i, i_end)]
# create embeddings
embeds = embedding.embed_documents(lines_batch)
embeds = embedding.embed_texts(lines_batch)
# prep metadata and upsert batch
if metadatas:
metadata = metadatas[i:i_end]

View File

@@ -139,7 +139,7 @@ class Qdrant(VectorStore):
List of floats representing the texts embedding.
"""
if self.embeddings is not None:
embeddings = self.embeddings.embed_documents(list(texts))
embeddings = self.embeddings.embed_texts(list(texts))
if hasattr(embeddings, "tolist"):
embeddings = embeddings.tolist()
elif self._embeddings_function is not None:
@@ -388,7 +388,7 @@ class Qdrant(VectorStore):
from qdrant_client.http import models as rest
# Just do a single quick embedding to get vector size
partial_embeddings = embedding.embed_documents(texts[:1])
partial_embeddings = embedding.embed_texts(texts[:1])
vector_size = len(partial_embeddings[0])
collection_name = collection_name or uuid.uuid4().hex
@@ -418,7 +418,7 @@ class Qdrant(VectorStore):
)
# Now generate the embeddings for all the texts
embeddings = embedding.embed_documents(texts)
embeddings = embedding.embed_texts(texts)
client.upsert(
collection_name=collection_name,

View File

@@ -408,7 +408,7 @@ class Redis(VectorStore):
)
# Create embeddings over documents
embeddings = embedding.embed_documents(texts)
embeddings = embedding.embed_texts(texts)
# Create the search index
instance._create_index(dim=len(embeddings[0]), distance_metric=distance_metric)

View File

@@ -74,7 +74,7 @@ class SupabaseVectorStore(VectorStore):
) -> List[str]:
docs = self._texts_to_documents(texts, metadatas)
vectors = self._embedding.embed_documents(list(texts))
vectors = self._embedding.embed_texts(list(texts))
return self.add_vectors(vectors, docs)
@classmethod
@@ -96,7 +96,7 @@ class SupabaseVectorStore(VectorStore):
if not table_name:
raise ValueError("Supabase document table_name is required.")
embeddings = embedding.embed_documents(texts)
embeddings = embedding.embed_texts(texts)
docs = cls._texts_to_documents(texts, metadatas)
_ids = cls._add_vectors(client, table_name, embeddings, docs)
@@ -115,7 +115,7 @@ class SupabaseVectorStore(VectorStore):
def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
vectors = self._embedding.embed_documents([query])
vectors = self._embedding.embed_texts([query])
return self.similarity_search_by_vector(vectors[0], k)
def similarity_search_by_vector(
@@ -130,7 +130,7 @@ class SupabaseVectorStore(VectorStore):
def similarity_search_with_relevance_scores(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Tuple[Document, float]]:
vectors = self._embedding.embed_documents([query])
vectors = self._embedding.embed_texts([query])
return self.similarity_search_by_vector_with_relevance_scores(vectors[0], k)
def similarity_search_by_vector_with_relevance_scores(
@@ -328,7 +328,7 @@ class SupabaseVectorStore(VectorStore):
END;
$$;```
"""
embedding = self._embedding.embed_documents([query])
embedding = self._embedding.embed_texts([query])
docs = self.max_marginal_relevance_search_by_vector(
embedding[0], k, fetch_k, lambda_mult=lambda_mult
)

View File

@@ -82,7 +82,7 @@ class Tair(VectorStore):
keys = kwargs.get("keys", None)
# Write data to tair
pipeline = self.client.pipeline(transaction=False)
embeddings = self.embedding_function.embed_documents(list(texts))
embeddings = self.embedding_function.embed_texts(list(texts))
for i, text in enumerate(texts):
# Use provided key otherwise use default key
key = keys[i] if keys else _uuid_key()
@@ -191,7 +191,7 @@ class Tair(VectorStore):
raise ValueError(f"tair failed to connect: {e}")
# Create embeddings for documents
embeddings = embedding.embed_documents(texts)
embeddings = embedding.embed_texts(texts)
tair_vector_store.create_index_if_not_exist(
len(embeddings[0]),

View File

@@ -142,7 +142,7 @@ class Weaviate(VectorStore):
)
if self._embedding is not None:
vector = self._embedding.embed_documents([text])[0]
vector = self._embedding.embed_texts([text])[0]
else:
vector = None
batch.add_data_object(
@@ -391,7 +391,7 @@ class Weaviate(VectorStore):
from weaviate.util import get_valid_uuid
index_name = kwargs.get("index_name", f"LangChain_{uuid4().hex}")
embeddings = embedding.embed_documents(texts) if embedding else None
embeddings = embedding.embed_texts(texts) if embedding else None
text_key = "text"
schema = _default_schema(index_name)
attributes = list(metadatas[0].keys()) if metadatas else None

View File

@@ -6,7 +6,7 @@ def test_cohere_embedding_documents() -> None:
"""Test cohere embeddings."""
documents = ["foo bar"]
embedding = CohereEmbeddings()
output = embedding.embed_documents(documents)
output = embedding.embed_texts(documents)
assert len(output) == 1
assert len(output[0]) == 2048

View File

@@ -10,7 +10,7 @@ def test_google_palm_embedding_documents() -> None:
"""Test Google PaLM embeddings."""
documents = ["foo bar"]
embedding = GooglePalmEmbeddings()
output = embedding.embed_documents(documents)
output = embedding.embed_texts(documents)
assert len(output) == 1
assert len(output[0]) == 768
@@ -19,7 +19,7 @@ def test_google_palm_embedding_documents_multiple() -> None:
"""Test Google PaLM embeddings."""
documents = ["foo bar", "bar foo", "foo"]
embedding = GooglePalmEmbeddings()
output = embedding.embed_documents(documents)
output = embedding.embed_texts(documents)
assert len(output) == 3
assert len(output[0]) == 768
assert len(output[1]) == 768

View File

@@ -10,7 +10,7 @@ def test_huggingface_embedding_documents() -> None:
"""Test huggingface embeddings."""
documents = ["foo bar"]
embedding = HuggingFaceEmbeddings()
output = embedding.embed_documents(documents)
output = embedding.embed_texts(documents)
assert len(output) == 1
assert len(output[0]) == 768
@@ -27,7 +27,7 @@ def test_huggingface_instructor_embedding_documents() -> None:
"""Test huggingface embeddings."""
documents = ["foo bar"]
embedding = HuggingFaceInstructEmbeddings()
output = embedding.embed_documents(documents)
output = embedding.embed_texts(documents)
assert len(output) == 1
assert len(output[0]) == 768

View File

@@ -8,7 +8,7 @@ def test_huggingfacehub_embedding_documents() -> None:
"""Test huggingfacehub embeddings."""
documents = ["foo bar"]
embedding = HuggingFaceHubEmbeddings()
output = embedding.embed_documents(documents)
output = embedding.embed_texts(documents)
assert len(output) == 1
assert len(output[0]) == 768

View File

@@ -6,7 +6,7 @@ def test_jina_embedding_documents() -> None:
"""Test jina embeddings for documents."""
documents = ["foo bar", "bar foo"]
embedding = JinaEmbeddings()
output = embedding.embed_documents(documents)
output = embedding.embed_texts(documents)
assert len(output) == 2
assert len(output[0]) == 512

View File

@@ -32,7 +32,7 @@ def test_llamacpp_embedding_documents() -> None:
documents = ["foo bar"]
model_path = get_model()
embedding = LlamaCppEmbeddings(model_path=model_path)
output = embedding.embed_documents(documents)
output = embedding.embed_texts(documents)
assert len(output) == 1
assert len(output[0]) == 512

View File

@@ -9,7 +9,7 @@ def test_openai_embedding_documents() -> None:
"""Test openai embeddings."""
documents = ["foo bar"]
embedding = OpenAIEmbeddings()
output = embedding.embed_documents(documents)
output = embedding.embed_texts(documents)
assert len(output) == 1
assert len(output[0]) == 1536
@@ -19,7 +19,7 @@ def test_openai_embedding_documents_multiple() -> None:
documents = ["foo bar", "bar foo", "foo"]
embedding = OpenAIEmbeddings(chunk_size=2)
embedding.embedding_ctx_length = 8191
output = embedding.embed_documents(documents)
output = embedding.embed_texts(documents)
assert len(output) == 3
assert len(output[0]) == 1536
assert len(output[1]) == 1536
@@ -38,7 +38,7 @@ def test_openai_embedding_with_empty_string() -> None:
"""Test openai embeddings with empty string."""
document = ["", "abc"]
embedding = OpenAIEmbeddings()
output = embedding.embed_documents(document)
output = embedding.embed_texts(document)
assert len(output) == 2
assert len(output[0]) == 1536
expected_output = openai.Embedding.create(input="", model="text-embedding-ada-002")[

View File

@@ -24,7 +24,7 @@ def test_self_hosted_huggingface_embedding_documents() -> None:
documents = ["foo bar"]
gpu = get_remote_instance()
embedding = SelfHostedHuggingFaceEmbeddings(hardware=gpu)
output = embedding.embed_documents(documents)
output = embedding.embed_texts(documents)
assert len(output) == 1
assert len(output[0]) == 768
@@ -43,7 +43,7 @@ def test_self_hosted_huggingface_instructor_embedding_documents() -> None:
documents = ["foo bar"]
gpu = get_remote_instance()
embedding = SelfHostedHuggingFaceInstructEmbeddings(hardware=gpu)
output = embedding.embed_documents(documents)
output = embedding.embed_texts(documents)
assert len(output) == 1
assert len(output[0]) == 768
@@ -80,7 +80,7 @@ def test_self_hosted_embedding_documents() -> None:
embedding = SelfHostedEmbeddings(
model_load_fn=get_pipeline, hardware=gpu, inference_fn=inference_fn
)
output = embedding.embed_documents(documents)
output = embedding.embed_texts(documents)
assert len(output) == 2
assert len(output[0]) == 50265

View File

@@ -9,7 +9,7 @@ def test_sentence_transformer_embedding_documents() -> None:
"""Test sentence_transformer embeddings."""
embedding = SentenceTransformerEmbeddings()
documents = ["foo bar"]
output = embedding.embed_documents(documents)
output = embedding.embed_texts(documents)
assert len(output) == 1
assert len(output[0]) == 384

View File

@@ -6,7 +6,7 @@ def test_tensorflowhub_embedding_documents() -> None:
"""Test tensorflowhub embeddings."""
documents = ["foo bar"]
embedding = TensorflowHubEmbeddings()
output = embedding.embed_documents(documents)
output = embedding.embed_texts(documents)
assert len(output) == 1
assert len(output[0]) == 512

View File

@@ -9,7 +9,7 @@ fake_texts = ["foo", "bar", "baz"]
class FakeEmbeddings(Embeddings):
"""Fake embeddings functionality for testing."""
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_texts(self, texts: List[str]) -> List[List[float]]:
"""Return simple embeddings.
Embeddings encode each text as its index."""
return [[float(1.0)] * 9 + [float(i)] for i in range(len(texts))]

View File

@@ -24,7 +24,7 @@ ADA_TOKEN_COUNT = 1536
class FakeEmbeddingsWithAdaDimension(FakeEmbeddings):
"""Fake embeddings functionality for testing."""
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_texts(self, texts: List[str]) -> List[List[float]]:
"""Return simple embeddings."""
return [
[float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(i)] for i in range(len(texts))

View File

@@ -143,7 +143,7 @@ def test_similarity_search_by_vector(
deeplake_datastore: DeepLake, distance_metric: str
) -> None:
"""Test similarity search by vector."""
embeddings = FakeEmbeddings().embed_documents(["foo", "bar", "baz"])
embeddings = FakeEmbeddings().embed_texts(["foo", "bar", "baz"])
output = deeplake_datastore.similarity_search_by_vector(
embeddings[1], k=1, distance_metric=distance_metric
)
@@ -185,7 +185,7 @@ def test_max_marginal_relevance_search(deeplake_datastore: DeepLake) -> None:
assert output == [Document(page_content="foo", metadata={"page": "0"})]
embeddings = FakeEmbeddings().embed_documents(["foo", "bar", "baz"])
embeddings = FakeEmbeddings().embed_texts(["foo", "bar", "baz"])
output = deeplake_datastore.max_marginal_relevance_search_by_vector(
embeddings[0], k=1, fetch_k=2
)

View File

@@ -8,7 +8,7 @@ def test_lancedb() -> None:
embeddings = FakeEmbeddings()
db = lancedb.connect("/tmp/lancedb")
texts = ["text 1", "text 2", "item 3"]
vectors = embeddings.embed_documents(texts)
vectors = embeddings.embed_texts(texts)
table = db.create_table(
"my_table",
data=[
@@ -27,7 +27,7 @@ def test_lancedb_add_texts() -> None:
embeddings = FakeEmbeddings()
db = lancedb.connect("/tmp/lancedb")
texts = ["text 1"]
vectors = embeddings.embed_documents(texts)
vectors = embeddings.embed_texts(texts)
table = db.create_table(
"my_table",
data=[

View File

@@ -24,7 +24,7 @@ ADA_TOKEN_COUNT = 1536
class FakeEmbeddingsWithAdaDimension(FakeEmbeddings):
"""Fake embeddings functionality for testing."""
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_texts(self, texts: List[str]) -> List[List[float]]:
"""Return simple embeddings."""
return [
[float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(i)] for i in range(len(texts))

View File

@@ -17,7 +17,7 @@ from langchain.schema import Generation, LLMResult
class FakeEmbeddings(Embeddings):
"""Fake embedding class for tests."""
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_texts(self, texts: List[str]) -> List[List[float]]:
"""Return random floats."""
return [list(np.random.uniform(0, 1, 10)) for _ in range(10)]