mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-16 15:04:13 +00:00
community: Cassandra Vector Store: extend metadata-related methods (#27078)
**Description:** this PR adds a set of methods to deal with metadata associated with the vector store entries. These, while essential to the Graph-related extension of the `Cassandra` vector store, are also useful in themselves. These are (all come in their sync+async versions): - `[a]delete_by_metadata_filter` - `[a]replace_metadata` - `[a]get_by_document_id` - `[a]metadata_search` Additionally, a `[a]similarity_search_with_embedding_id_by_vector` method is introduced to better serve the store's internal working (esp. related to reranking logic). **Issue:** no issue number, but now all Documents returned bear their `.id` consistently (as a consequence of a slight refactoring in how the raw entries read from DB are made back into `Document` instances). **Dependencies:** (no new deps: packaging comes through langchain-core already; `cassio` is now required to be version 0.1.10+) **Add tests and docs** Added integration tests for the relevant newly-introduced methods. (Docs will be updated in a separate PR). **Lint and test** Lint and (updated) test all pass. --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
@@ -17,6 +17,17 @@ from tests.integration_tests.vectorstores.fake_embeddings import (
|
||||
)
|
||||
|
||||
|
||||
def _strip_docs(documents: List[Document]) -> List[Document]:
    """Return id-stripped copies of *documents* (see ``_strip_doc``)."""
    return list(map(_strip_doc, documents))
def _strip_doc(document: Document) -> Document:
    """Rebuild *document* keeping only content and metadata (drops ``.id``)."""
    content = document.page_content
    metadata = document.metadata
    return Document(page_content=content, metadata=metadata)
def _vectorstore_from_texts(
|
||||
texts: List[str],
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
@@ -110,9 +121,9 @@ async def test_cassandra() -> None:
|
||||
texts = ["foo", "bar", "baz"]
|
||||
docsearch = _vectorstore_from_texts(texts)
|
||||
output = docsearch.similarity_search("foo", k=1)
|
||||
assert output == [Document(page_content="foo")]
|
||||
assert _strip_docs(output) == _strip_docs([Document(page_content="foo")])
|
||||
output = await docsearch.asimilarity_search("foo", k=1)
|
||||
assert output == [Document(page_content="foo")]
|
||||
assert _strip_docs(output) == _strip_docs([Document(page_content="foo")])
|
||||
|
||||
|
||||
async def test_cassandra_with_score() -> None:
|
||||
@@ -130,13 +141,13 @@ async def test_cassandra_with_score() -> None:
|
||||
output = docsearch.similarity_search_with_score("foo", k=3)
|
||||
docs = [o[0] for o in output]
|
||||
scores = [o[1] for o in output]
|
||||
assert docs == expected_docs
|
||||
assert _strip_docs(docs) == _strip_docs(expected_docs)
|
||||
assert scores[0] > scores[1] > scores[2]
|
||||
|
||||
output = await docsearch.asimilarity_search_with_score("foo", k=3)
|
||||
docs = [o[0] for o in output]
|
||||
scores = [o[1] for o in output]
|
||||
assert docs == expected_docs
|
||||
assert _strip_docs(docs) == _strip_docs(expected_docs)
|
||||
assert scores[0] > scores[1] > scores[2]
|
||||
|
||||
|
||||
@@ -239,7 +250,7 @@ async def test_cassandra_no_drop_async() -> None:
|
||||
def test_cassandra_delete() -> None:
|
||||
"""Test delete methods from vector store."""
|
||||
texts = ["foo", "bar", "baz", "gni"]
|
||||
metadatas = [{"page": i} for i in range(len(texts))]
|
||||
metadatas = [{"page": i, "mod2": i % 2} for i in range(len(texts))]
|
||||
docsearch = _vectorstore_from_texts([], metadatas=metadatas)
|
||||
|
||||
ids = docsearch.add_texts(texts, metadatas)
|
||||
@@ -263,11 +274,21 @@ def test_cassandra_delete() -> None:
|
||||
output = docsearch.similarity_search("foo", k=10)
|
||||
assert len(output) == 0
|
||||
|
||||
docsearch.add_texts(texts, metadatas)
|
||||
num_deleted = docsearch.delete_by_metadata_filter({"mod2": 0}, batch_size=1)
|
||||
assert num_deleted == 2
|
||||
output = docsearch.similarity_search("foo", k=10)
|
||||
assert len(output) == 2
|
||||
docsearch.clear()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
docsearch.delete_by_metadata_filter({})
|
||||
|
||||
|
||||
async def test_cassandra_adelete() -> None:
|
||||
"""Test delete methods from vector store."""
|
||||
texts = ["foo", "bar", "baz", "gni"]
|
||||
metadatas = [{"page": i} for i in range(len(texts))]
|
||||
metadatas = [{"page": i, "mod2": i % 2} for i in range(len(texts))]
|
||||
docsearch = await _vectorstore_from_texts_async([], metadatas=metadatas)
|
||||
|
||||
ids = await docsearch.aadd_texts(texts, metadatas)
|
||||
@@ -291,6 +312,16 @@ async def test_cassandra_adelete() -> None:
|
||||
output = docsearch.similarity_search("foo", k=10)
|
||||
assert len(output) == 0
|
||||
|
||||
await docsearch.aadd_texts(texts, metadatas)
|
||||
num_deleted = await docsearch.adelete_by_metadata_filter({"mod2": 0}, batch_size=1)
|
||||
assert num_deleted == 2
|
||||
output = await docsearch.asimilarity_search("foo", k=10)
|
||||
assert len(output) == 2
|
||||
await docsearch.aclear()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await docsearch.adelete_by_metadata_filter({})
|
||||
|
||||
|
||||
def test_cassandra_metadata_indexing() -> None:
|
||||
"""Test comparing metadata indexing policies."""
|
||||
@@ -316,3 +347,107 @@ def test_cassandra_metadata_indexing() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
# "Non-indexed metadata fields cannot be used in queries."
|
||||
vstore_f1.similarity_search("bar", filter={"field2": "b"}, k=2)
|
||||
|
||||
|
||||
def test_cassandra_replace_metadata() -> None:
    """Exercise ``replace_metadata``: replace metadata on half the docs, check all."""
    num_docs = 100
    replace_every = 2  # one in ... will have replaced metadata
    batch_size = 3

    store = _vectorstore_from_texts(
        texts=[],
        metadata_indexing=("allowlist", ["field1", "field2"]),
        table_name="vector_test_table_indexing",
    )
    original_docs = []
    for index in range(num_docs):
        original_docs.append(
            Document(
                page_content=f"doc_{index}",
                id=f"doc_id_{index}",
                metadata={"field1": f"f1_{index}", "otherf": "pre"},
            )
        )
    store.add_documents(original_docs)

    replaced_ids = [
        f"doc_id_{index}" for index in range(num_docs) if index % replace_every == 0
    ]

    # Cycle through four replacement flavors: empty, indexed-field only,
    # indexed + non-indexed, non-indexed only.
    # ("ofherf2" — possibly a typo for "otherf2" — is kept as-is; it is
    # plain test data and must match the expectation map built below.)
    def _new_metadata(mode: int, doc_id: str) -> dict[str, str]:
        if mode == 1:
            return {"field2": f"NEW_{doc_id}"}
        if mode == 2:
            return {"field2": f"NEW_{doc_id}", "ofherf2": "post"}
        if mode == 3:
            return {"ofherf2": "post"}
        return {}

    new_md_by_id = {
        doc_id: _new_metadata(position % 4, doc_id)
        for position, doc_id in enumerate(replaced_ids)
    }

    store.replace_metadata(new_md_by_id, batch_size=batch_size)

    # Thorough check: every stored doc now carries either its original
    # metadata or its replacement.
    expected_by_id: dict[str, dict] = {
        (doc.id or ""): doc.metadata for doc in original_docs
    }
    expected_by_id.update(new_md_by_id)
    for hit in store.similarity_search("doc", k=num_docs + 1):
        assert hit.id is not None
        assert hit.metadata == expected_by_id[hit.id]
||||
async def test_cassandra_areplace_metadata() -> None:
    """Test of replacing metadata, async version.

    Mirrors ``test_cassandra_replace_metadata`` but drives the store through
    its async API (``aadd_documents`` / ``areplace_metadata`` /
    ``asimilarity_search``).
    """
    N_DOCS = 100
    REPLACE_RATIO = 2  # one in ... will have replaced metadata
    BATCH_SIZE = 3

    # Consistency fix: use the async store builder, as the other async tests
    # (e.g. ``test_cassandra_adelete``) do, instead of the sync
    # ``_vectorstore_from_texts``.
    vstore_f1 = await _vectorstore_from_texts_async(
        texts=[],
        metadata_indexing=("allowlist", ["field1", "field2"]),
        table_name="vector_test_table_indexing",
    )
    orig_documents = [
        Document(
            page_content=f"doc_{doc_i}",
            id=f"doc_id_{doc_i}",
            metadata={"field1": f"f1_{doc_i}", "otherf": "pre"},
        )
        for doc_i in range(N_DOCS)
    ]
    await vstore_f1.aadd_documents(orig_documents)

    ids_to_replace = [
        f"doc_id_{doc_i}" for doc_i in range(N_DOCS) if doc_i % REPLACE_RATIO == 0
    ]

    # various kinds of replacement at play here:
    # mode 0 = empty, 1 = indexed field only, 2 = indexed + non-indexed,
    # 3 = non-indexed only.  ("ofherf2" — possibly a typo for "otherf2" —
    # kept as-is: it is plain test data matched by the expectation map below.)
    def _make_new_md(mode: int, doc_id: str) -> dict[str, str]:
        if mode == 0:
            return {}
        elif mode == 1:
            return {"field2": f"NEW_{doc_id}"}
        elif mode == 2:
            return {"field2": f"NEW_{doc_id}", "ofherf2": "post"}
        else:
            return {"ofherf2": "post"}

    ids_to_new_md = {
        doc_id: _make_new_md(rep_i % 4, doc_id)
        for rep_i, doc_id in enumerate(ids_to_replace)
    }

    # NOTE: the async API takes ``concurrency`` where the sync one takes
    # ``batch_size``.
    await vstore_f1.areplace_metadata(ids_to_new_md, concurrency=BATCH_SIZE)
    # thorough check: every stored doc bears either its original or its
    # replacement metadata.
    expected_id_to_metadata: dict[str, dict] = {
        **{(document.id or ""): document.metadata for document in orig_documents},
        **ids_to_new_md,
    }
    for hit in await vstore_f1.asimilarity_search("doc", k=N_DOCS + 1):
        assert hit.id is not None
        assert hit.metadata == expected_id_to_metadata[hit.id]
Reference in New Issue
Block a user