Mirror of https://github.com/hwchase17/langchain.git — synced 2025-08-09 13:00:34 +00:00
community: Cassandra Vector Store: extend metadata-related methods (#27078)
**Description:** this PR adds a set of methods to deal with the metadata associated with the vector store entries. While essential to the Graph-related extension of the `Cassandra` vector store, these are also useful in themselves. They are (each comes in a sync and an async version):

- `[a]delete_by_metadata_filter`
- `[a]replace_metadata`
- `[a]get_by_document_id`
- `[a]metadata_search`

Additionally, an `[a]similarity_search_with_embedding_id_by_vector` method is introduced to better serve the store's internal workings (especially the reranking logic).

**Issue:** no issue number, but all returned `Document`s now bear their `.id` consistently (a consequence of a slight refactoring of how the raw entries read from the DB are turned back into `Document` instances).

**Dependencies:** no new dependencies: `packaging` already comes through `langchain-core`; `cassio` is now required to be version 0.1.10 or higher.

**Add tests and docs:** integration tests added for the relevant newly introduced methods. (Docs will be updated in a separate PR.)

**Lint and test:** lint and the (updated) tests all pass.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent 84c05b031d
commit d05fdd97dd
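Ahead of the diff, a minimal usage sketch of the new sync-side methods (not part of the commit). It assumes a locally reachable Cassandra cluster; the contact point, keyspace, table name, document IDs, and the `FakeEmbeddings` stand-in are illustrative placeholders.

```python
from cassandra.cluster import Cluster
from langchain_community.embeddings import FakeEmbeddings
from langchain_community.vectorstores import Cassandra

# Placeholder session/keyspace; any cassio-compatible setup works.
session = Cluster(["127.0.0.1"]).connect()
store = Cassandra(
    embedding=FakeEmbeddings(size=16),  # stand-in embedding model
    session=session,
    keyspace="demo_ks",
    table_name="demo_vectors",
)
store.add_texts(
    ["alpha", "beta"],
    metadatas=[{"tag": "keep"}, {"tag": "drop"}],
    ids=["doc-1", "doc-2"],
)

# New in this PR: metadata-only search, point lookup by document ID,
# wholesale metadata replacement, and deletion by metadata filter.
kept = list(store.metadata_search(metadata={"tag": "keep"}, n=5))
doc = store.get_by_document_id(document_id="doc-1")
store.replace_metadata({"doc-1": {"tag": "replaced"}}, batch_size=50)
n_deleted = store.delete_by_metadata_filter({"tag": "drop"}, batch_size=1)
```

The async twins take the same arguments, except that `areplace_metadata` exposes a `concurrency` parameter instead of `batch_size`; an empty filter passed to `[a]delete_by_metadata_filter` raises a `ValueError` rather than clearing the store.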
@@ -1,6 +1,7 @@
 from __future__ import annotations

 import asyncio
+import importlib.metadata
 import typing
 import uuid
 from typing import (
@@ -18,6 +19,7 @@ from typing import (
 )

 import numpy as np
+from packaging.version import Version  # this is a langchain-core dependency

 if typing.TYPE_CHECKING:
     from cassandra.cluster import Session
@@ -30,6 +32,7 @@ from langchain_community.utilities.cassandra import SetupMode
 from langchain_community.vectorstores.utils import maximal_marginal_relevance

 CVST = TypeVar("CVST", bound="Cassandra")
+MIN_CASSIO_VERSION = Version("0.1.10")


 class Cassandra(VectorStore):
@@ -110,6 +113,15 @@ class Cassandra(VectorStore):
                 "Could not import cassio python package. "
                 "Please install it with `pip install cassio`."
             )
+        cassio_version = Version(importlib.metadata.version("cassio"))
+
+        if cassio_version is not None and cassio_version < MIN_CASSIO_VERSION:
+            msg = (
+                "Cassio version not supported. Please upgrade cassio "
+                f"to version {MIN_CASSIO_VERSION} or higher."
+            )
+            raise ImportError(msg)
+
         if not table_name:
             raise ValueError("Missing required parameter 'table_name'.")
         self.embedding = embedding
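An aside, not from the commit: the guard above works because `packaging.version.Version` compares release segments numerically, so `0.1.10` really is newer than `0.1.9`, where a plain string comparison would claim the opposite.

```python
from packaging.version import Version

MIN_CASSIO_VERSION = Version("0.1.10")

assert "0.1.9" > "0.1.10"                     # lexicographic: misleading
assert Version("0.1.9") < MIN_CASSIO_VERSION  # segment-wise: correct
assert Version("0.2.0") > MIN_CASSIO_VERSION
```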
@@ -143,6 +155,9 @@ class Cassandra(VectorStore):
             **kwargs,
         )

+        if self.session is None:
+            self.session = self.table.session
+
     @property
     def embeddings(self) -> Embeddings:
         return self.embedding
@@ -231,6 +246,70 @@ class Cassandra(VectorStore):
             await self.adelete_by_document_id(document_id)
         return True

+    def delete_by_metadata_filter(
+        self,
+        filter: dict[str, Any],
+        *,
+        batch_size: int = 50,
+    ) -> int:
+        """Delete all documents matching a certain metadata filtering condition.
+
+        This operation does not use the vector embeddings in any way, it simply
+        removes all documents whose metadata match the provided condition.
+
+        Args:
+            filter: Filter on the metadata to apply. The filter cannot be empty.
+            batch_size: amount of deletions per each batch (until exhaustion of
+                the matching documents).
+
+        Returns:
+            A number expressing the amount of deleted documents.
+        """
+        if not filter:
+            msg = (
+                "Method `delete_by_metadata_filter` does not accept an empty "
+                "filter. Use the `clear()` method if you really want to empty "
+                "the vector store."
+            )
+            raise ValueError(msg)
+
+        return self.table.find_and_delete_entries(
+            metadata=filter,
+            batch_size=batch_size,
+        )
+
+    async def adelete_by_metadata_filter(
+        self,
+        filter: dict[str, Any],
+        *,
+        batch_size: int = 50,
+    ) -> int:
+        """Delete all documents matching a certain metadata filtering condition.
+
+        This operation does not use the vector embeddings in any way, it simply
+        removes all documents whose metadata match the provided condition.
+
+        Args:
+            filter: Filter on the metadata to apply. The filter cannot be empty.
+            batch_size: amount of deletions per each batch (until exhaustion of
+                the matching documents).
+
+        Returns:
+            A number expressing the amount of deleted documents.
+        """
+        if not filter:
+            msg = (
+                "Method `delete_by_metadata_filter` does not accept an empty "
+                "filter. Use the `clear()` method if you really want to empty "
+                "the vector store."
+            )
+            raise ValueError(msg)
+
+        return await self.table.afind_and_delete_entries(
+            metadata=filter,
+            batch_size=batch_size,
+        )
+
     def add_texts(
         self,
         texts: Iterable[str],
@@ -333,6 +412,180 @@ class Cassandra(VectorStore):
         await asyncio.gather(*tasks)
         return ids

+    def replace_metadata(
+        self,
+        id_to_metadata: dict[str, dict],
+        *,
+        batch_size: int = 50,
+    ) -> None:
+        """Replace the metadata of documents.
+
+        For each document to update, identified by its ID, the new metadata
+        dictionary completely replaces what is on the store. This includes
+        passing empty metadata `{}` to erase the currently-stored information.
+
+        Args:
+            id_to_metadata: map from the Document IDs to modify to the
+                new metadata for updating.
+                Keys in this dictionary that do not correspond to an existing
+                document will not cause an error, rather will result in new
+                rows being written into the Cassandra table but without an
+                associated vector: hence unreachable through vector search.
+            batch_size: Number of concurrent requests to send to the server.
+
+        Returns:
+            None if the writes succeed (otherwise an error is raised).
+        """
+        ids_and_metadatas = list(id_to_metadata.items())
+        for i in range(0, len(ids_and_metadatas), batch_size):
+            batch_i_m = ids_and_metadatas[i : i + batch_size]
+            futures = [
+                self.table.put_async(
+                    row_id=doc_id,
+                    metadata=doc_md,
+                )
+                for doc_id, doc_md in batch_i_m
+            ]
+            for future in futures:
+                future.result()
+        return
+
+    async def areplace_metadata(
+        self,
+        id_to_metadata: dict[str, dict],
+        *,
+        concurrency: int = 50,
+    ) -> None:
+        """Replace the metadata of documents.
+
+        For each document to update, identified by its ID, the new metadata
+        dictionary completely replaces what is on the store. This includes
+        passing empty metadata `{}` to erase the currently-stored information.
+
+        Args:
+            id_to_metadata: map from the Document IDs to modify to the
+                new metadata for updating.
+                Keys in this dictionary that do not correspond to an existing
+                document will not cause an error, rather will result in new
+                rows being written into the Cassandra table but without an
+                associated vector: hence unreachable through vector search.
+            concurrency: Number of concurrent queries to the database.
+                Defaults to 50.
+
+        Returns:
+            None if the writes succeed (otherwise an error is raised).
+        """
+        ids_and_metadatas = list(id_to_metadata.items())
+
+        sem = asyncio.Semaphore(concurrency)
+
+        async def send_concurrently(doc_id: str, doc_md: dict) -> None:
+            async with sem:
+                await self.table.aput(
+                    row_id=doc_id,
+                    metadata=doc_md,
+                )
+
+        for doc_id, doc_md in ids_and_metadatas:
+            tasks = [asyncio.create_task(send_concurrently(doc_id, doc_md))]
+            await asyncio.gather(*tasks)
+
+        return
+
+    @staticmethod
+    def _row_to_document(row: Dict[str, Any]) -> Document:
+        return Document(
+            id=row["row_id"],
+            page_content=row["body_blob"],
+            metadata=row["metadata"],
+        )
+
+    def get_by_document_id(self, document_id: str) -> Document | None:
+        """Get by document ID.
+
+        Args:
+            document_id: the document ID to get.
+        """
+        row = self.table.get(row_id=document_id)
+        if row is None:
+            return None
+        return self._row_to_document(row=row)
+
+    async def aget_by_document_id(self, document_id: str) -> Document | None:
+        """Get by document ID.
+
+        Args:
+            document_id: the document ID to get.
+        """
+        row = await self.table.aget(row_id=document_id)
+        if row is None:
+            return None
+        return self._row_to_document(row=row)
+
+    def metadata_search(
+        self,
+        metadata: dict[str, Any] = {},  # noqa: B006
+        n: int = 5,
+    ) -> Iterable[Document]:
+        """Get documents via a metadata search.
+
+        Args:
+            metadata: the metadata to query for.
+        """
+        rows = self.table.find_entries(metadata=metadata, n=n)
+        return [self._row_to_document(row=row) for row in rows if row]
+
+    async def ametadata_search(
+        self,
+        metadata: dict[str, Any] = {},  # noqa: B006
+        n: int = 5,
+    ) -> Iterable[Document]:
+        """Get documents via a metadata search.
+
+        Args:
+            metadata: the metadata to query for.
+        """
+        rows = await self.table.afind_entries(metadata=metadata, n=n)
+        return [self._row_to_document(row=row) for row in rows]
+
+    async def asimilarity_search_with_embedding_id_by_vector(
+        self,
+        embedding: List[float],
+        k: int = 4,
+        filter: Optional[Dict[str, str]] = None,
+        body_search: Optional[Union[str, List[str]]] = None,
+    ) -> List[Tuple[Document, List[float], str]]:
+        """Return docs most similar to embedding vector.
+
+        Args:
+            embedding: Embedding to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+            filter: Filter on the metadata to apply.
+            body_search: Document textual search terms to apply.
+                Only supported by Astra DB at the moment.
+
+        Returns:
+            List of (Document, embedding, id), the most similar to the query vector.
+        """
+        kwargs: Dict[str, Any] = {}
+        if filter is not None:
+            kwargs["metadata"] = filter
+        if body_search is not None:
+            kwargs["body_search"] = body_search
+
+        hits = await self.table.aann_search(
+            vector=embedding,
+            n=k,
+            **kwargs,
+        )
+        return [
+            (
+                self._row_to_document(row=hit),
+                hit["vector"],
+                hit["row_id"],
+            )
+            for hit in hits
+        ]
+
     @staticmethod
     def _search_to_documents(
         hits: Iterable[Dict[str, Any]],
@@ -341,10 +594,7 @@ class Cassandra(VectorStore):
         # (1=most relevant), as required by this class' contract.
         return [
             (
-                Document(
-                    page_content=hit["body_blob"],
-                    metadata=hit["metadata"],
-                ),
+                Cassandra._row_to_document(row=hit),
                 0.5 + 0.5 * hit["distance"],
                 hit["row_id"],
             )
@@ -375,7 +625,6 @@
             kwargs["metadata"] = filter
         if body_search is not None:
             kwargs["body_search"] = body_search
-
         hits = self.table.metric_ann_search(
             vector=embedding,
             n=k,
@@ -712,13 +961,7 @@
             for pf_index, pf_hit in enumerate(prefetch_hits)
             if pf_index in mmr_chosen_indices
         ]
-        return [
-            Document(
-                page_content=hit["body_blob"],
-                metadata=hit["metadata"],
-            )
-            for hit in mmr_hits
-        ]
+        return [Cassandra._row_to_document(row=hit) for hit in mmr_hits]

     def max_marginal_relevance_search_by_vector(
         self,
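Before the diff moves on to the integration tests, a sketch of the async side (again not part of the commit): `store` is assumed to be a `Cassandra` instance configured as in the earlier sketch, and `query_vector` a placeholder embedding of the right dimension.

```python
import asyncio


async def demo(store, query_vector: list[float]) -> None:
    # Replace metadata wholesale; passing {} erases a document's metadata.
    await store.areplace_metadata(
        {"doc-1": {"tag": "new"}, "doc-2": {}},
        concurrency=10,
    )
    # ANN search returning (Document, embedding, id) triples -- the shape
    # the store's reranking logic consumes internally.
    hits = await store.asimilarity_search_with_embedding_id_by_vector(
        embedding=query_vector,
        k=4,
    )
    for doc, vector, doc_id in hits:
        print(doc_id, doc.metadata, len(vector))
```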
@@ -17,6 +17,17 @@ from tests.integration_tests.vectorstores.fake_embeddings import (
 )


+def _strip_docs(documents: List[Document]) -> List[Document]:
+    return [_strip_doc(doc) for doc in documents]
+
+
+def _strip_doc(document: Document) -> Document:
+    return Document(
+        page_content=document.page_content,
+        metadata=document.metadata,
+    )
+
+
 def _vectorstore_from_texts(
     texts: List[str],
     metadatas: Optional[List[dict]] = None,
@@ -110,9 +121,9 @@ async def test_cassandra() -> None:
     texts = ["foo", "bar", "baz"]
     docsearch = _vectorstore_from_texts(texts)
     output = docsearch.similarity_search("foo", k=1)
-    assert output == [Document(page_content="foo")]
+    assert _strip_docs(output) == _strip_docs([Document(page_content="foo")])
     output = await docsearch.asimilarity_search("foo", k=1)
-    assert output == [Document(page_content="foo")]
+    assert _strip_docs(output) == _strip_docs([Document(page_content="foo")])


 async def test_cassandra_with_score() -> None:
@@ -130,13 +141,13 @@ async def test_cassandra_with_score() -> None:
     output = docsearch.similarity_search_with_score("foo", k=3)
     docs = [o[0] for o in output]
     scores = [o[1] for o in output]
-    assert docs == expected_docs
+    assert _strip_docs(docs) == _strip_docs(expected_docs)
     assert scores[0] > scores[1] > scores[2]

     output = await docsearch.asimilarity_search_with_score("foo", k=3)
     docs = [o[0] for o in output]
     scores = [o[1] for o in output]
-    assert docs == expected_docs
+    assert _strip_docs(docs) == _strip_docs(expected_docs)
     assert scores[0] > scores[1] > scores[2]


@@ -239,7 +250,7 @@ async def test_cassandra_no_drop_async() -> None:
 def test_cassandra_delete() -> None:
     """Test delete methods from vector store."""
     texts = ["foo", "bar", "baz", "gni"]
-    metadatas = [{"page": i} for i in range(len(texts))]
+    metadatas = [{"page": i, "mod2": i % 2} for i in range(len(texts))]
     docsearch = _vectorstore_from_texts([], metadatas=metadatas)

     ids = docsearch.add_texts(texts, metadatas)
@@ -263,11 +274,21 @@ def test_cassandra_delete() -> None:
     output = docsearch.similarity_search("foo", k=10)
     assert len(output) == 0
+
+    docsearch.add_texts(texts, metadatas)
+    num_deleted = docsearch.delete_by_metadata_filter({"mod2": 0}, batch_size=1)
+    assert num_deleted == 2
+    output = docsearch.similarity_search("foo", k=10)
+    assert len(output) == 2
+    docsearch.clear()
+
+    with pytest.raises(ValueError):
+        docsearch.delete_by_metadata_filter({})


 async def test_cassandra_adelete() -> None:
     """Test delete methods from vector store."""
     texts = ["foo", "bar", "baz", "gni"]
-    metadatas = [{"page": i} for i in range(len(texts))]
+    metadatas = [{"page": i, "mod2": i % 2} for i in range(len(texts))]
     docsearch = await _vectorstore_from_texts_async([], metadatas=metadatas)

     ids = await docsearch.aadd_texts(texts, metadatas)
@@ -291,6 +312,16 @@ async def test_cassandra_adelete() -> None:
     output = docsearch.similarity_search("foo", k=10)
     assert len(output) == 0
+
+    await docsearch.aadd_texts(texts, metadatas)
+    num_deleted = await docsearch.adelete_by_metadata_filter({"mod2": 0}, batch_size=1)
+    assert num_deleted == 2
+    output = await docsearch.asimilarity_search("foo", k=10)
+    assert len(output) == 2
+    await docsearch.aclear()
+
+    with pytest.raises(ValueError):
+        await docsearch.adelete_by_metadata_filter({})


 def test_cassandra_metadata_indexing() -> None:
     """Test comparing metadata indexing policies."""
@@ -316,3 +347,107 @@ def test_cassandra_metadata_indexing() -> None:
     with pytest.raises(ValueError):
         # "Non-indexed metadata fields cannot be used in queries."
         vstore_f1.similarity_search("bar", filter={"field2": "b"}, k=2)
+
+
+def test_cassandra_replace_metadata() -> None:
+    """Test of replacing metadata."""
+    N_DOCS = 100
+    REPLACE_RATIO = 2  # one in ... will have replaced metadata
+    BATCH_SIZE = 3
+
+    vstore_f1 = _vectorstore_from_texts(
+        texts=[],
+        metadata_indexing=("allowlist", ["field1", "field2"]),
+        table_name="vector_test_table_indexing",
+    )
+    orig_documents = [
+        Document(
+            page_content=f"doc_{doc_i}",
+            id=f"doc_id_{doc_i}",
+            metadata={"field1": f"f1_{doc_i}", "otherf": "pre"},
+        )
+        for doc_i in range(N_DOCS)
+    ]
+    vstore_f1.add_documents(orig_documents)
+
+    ids_to_replace = [
+        f"doc_id_{doc_i}" for doc_i in range(N_DOCS) if doc_i % REPLACE_RATIO == 0
+    ]
+
+    # various kinds of replacement at play here:
+    def _make_new_md(mode: int, doc_id: str) -> dict[str, str]:
+        if mode == 0:
+            return {}
+        elif mode == 1:
+            return {"field2": f"NEW_{doc_id}"}
+        elif mode == 2:
+            return {"field2": f"NEW_{doc_id}", "ofherf2": "post"}
+        else:
+            return {"ofherf2": "post"}
+
+    ids_to_new_md = {
+        doc_id: _make_new_md(rep_i % 4, doc_id)
+        for rep_i, doc_id in enumerate(ids_to_replace)
+    }
+
+    vstore_f1.replace_metadata(ids_to_new_md, batch_size=BATCH_SIZE)
+    # thorough check
+    expected_id_to_metadata: dict[str, dict] = {
+        **{(document.id or ""): document.metadata for document in orig_documents},
+        **ids_to_new_md,
+    }
+    for hit in vstore_f1.similarity_search("doc", k=N_DOCS + 1):
+        assert hit.id is not None
+        assert hit.metadata == expected_id_to_metadata[hit.id]
+
+
+async def test_cassandra_areplace_metadata() -> None:
+    """Test of replacing metadata."""
+    N_DOCS = 100
+    REPLACE_RATIO = 2  # one in ... will have replaced metadata
+    BATCH_SIZE = 3
+
+    vstore_f1 = _vectorstore_from_texts(
+        texts=[],
+        metadata_indexing=("allowlist", ["field1", "field2"]),
+        table_name="vector_test_table_indexing",
+    )
+    orig_documents = [
+        Document(
+            page_content=f"doc_{doc_i}",
+            id=f"doc_id_{doc_i}",
+            metadata={"field1": f"f1_{doc_i}", "otherf": "pre"},
+        )
+        for doc_i in range(N_DOCS)
+    ]
+    await vstore_f1.aadd_documents(orig_documents)
+
+    ids_to_replace = [
+        f"doc_id_{doc_i}" for doc_i in range(N_DOCS) if doc_i % REPLACE_RATIO == 0
+    ]
+
+    # various kinds of replacement at play here:
+    def _make_new_md(mode: int, doc_id: str) -> dict[str, str]:
+        if mode == 0:
+            return {}
+        elif mode == 1:
+            return {"field2": f"NEW_{doc_id}"}
+        elif mode == 2:
+            return {"field2": f"NEW_{doc_id}", "ofherf2": "post"}
+        else:
+            return {"ofherf2": "post"}
+
+    ids_to_new_md = {
+        doc_id: _make_new_md(rep_i % 4, doc_id)
+        for rep_i, doc_id in enumerate(ids_to_replace)
+    }
+
+    await vstore_f1.areplace_metadata(ids_to_new_md, concurrency=BATCH_SIZE)
+    # thorough check
+    expected_id_to_metadata: dict[str, dict] = {
+        **{(document.id or ""): document.metadata for document in orig_documents},
+        **ids_to_new_md,
+    }
+    for hit in await vstore_f1.asimilarity_search("doc", k=N_DOCS + 1):
+        assert hit.id is not None
+        assert hit.metadata == expected_id_to_metadata[hit.id]