community: Cassandra Vector Store: modernize implementation (#27253)

**Description:** 

This PR updates `CassandraGraphVectorStore` to be based on
`CassandraVectorStore` instead of a custom CQL implementation. This
allows users of a `CassandraVectorStore` to upgrade to a
`GraphVectorStore` without changing their database schema or
re-embedding documents.

This PR also updates the documentation of the `GraphVectorStore` base
class and adds native async implementations of the standard graph
methods `traversal_search` and `mmr_traversal_search` in
`CassandraVectorStore`.
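
A minimal sketch of the sync and async traversal calls, assuming an
already-initialized store (the query text and parameters are
illustrative, not prescribed):

```python
# Sketch only: `store` is an initialized CassandraGraphVectorStore,
# constructed as in the integration tests below.
from langchain_community.graph_vectorstores import CassandraGraphVectorStore


async def traverse(store: CassandraGraphVectorStore) -> None:
    # Synchronous traversal (interface unchanged).
    for doc in store.traversal_search("greetings", k=2, depth=1):
        print(doc.id)
    # Natively async counterparts added by this PR.
    async for doc in store.atraversal_search("greetings", k=2, depth=1):
        print(doc.id)
    async for doc in store.ammr_traversal_search("greetings", k=2, depth=2):
        print(doc.id)
```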

**Issue:** No issue number.

**Dependencies:** https://github.com/langchain-ai/langchain/pull/27078
(already-merged)

**Lint and test**:
- Lint and tests all pass, including the existing
`CassandraGraphVectorStore` tests.
- Also added numerous additional tests based on the tests in
`langchain-astradb`, which cover many more scenarios than the existing
tests for `Cassandra` and `CassandraGraphVectorStore`.

**BREAKING CHANGE**

Note that this is a breaking change for existing users of
`CassandraGraphVectorStore`. They will need to wipe their database table
and restart.
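
A minimal sketch of that reset, assuming a cassandra-driver `session` and
the old node/edge table layout (the keyspace and table names here are
illustrative placeholders, not the library's defaults):

```python
# Hypothetical names: adjust to the keyspace/tables your old
# CassandraGraphVectorStore was created with.
session.execute("DROP TABLE IF EXISTS my_keyspace.my_node_table")
session.execute("DROP TABLE IF EXISTS my_keyspace.my_edge_table")
```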

However:
- The interfaces have not changed, just the underlying storage
mechanism.
- Anyone using `langchain_community.vectorstores.Cassandra` can instead
use `langchain_community.graph_vectorstores.CassandraGraphVectorStore`
and gain graph capabilities without having to re-embed their existing
documents. This is the primary goal of this PR.
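
A rough sketch of that upgrade path (session, keyspace, table name, and
embeddings are placeholders; the constructor parameters mirror the test
fixtures below):

```python
from cassandra.cluster import Cluster

from langchain_community.graph_vectorstores import CassandraGraphVectorStore
from langchain_community.vectorstores import Cassandra

session = Cluster().connect()
embeddings = ...  # placeholder: the same Embeddings used to populate the table

# Existing (non-graph) vector store, already populated and embedded.
v_store = Cassandra(
    embedding=embeddings,
    session=session,
    keyspace="my_keyspace",
    table_name="my_table",
)

# Point the graph store at the same table: same schema, no re-embedding.
g_store = CassandraGraphVectorStore(
    embedding=embeddings,
    session=session,
    keyspace="my_keyspace",
    table_name="my_table",
)
docs = list(g_store.traversal_search("my query", k=4, depth=1))
```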

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
Authored-by: Eric Pinzur, 2024-10-22 20:11:11 +02:00
Committed-by: GitHub
parent 0640cbf2f1
commit f636c83321
9 changed files with 4070 additions and 679 deletions


@@ -1,116 +1,255 @@
"""Test of Apache Cassandra graph vector store class `CassandraGraphVectorStore`"""

import json
import os
import random
from contextlib import contextmanager
from typing import Any, Generator, Iterable, List, Optional

import pytest
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings

from langchain_community.graph_vectorstores import CassandraGraphVectorStore
from langchain_community.graph_vectorstores.base import Node
from langchain_community.graph_vectorstores.links import (
    METADATA_LINKS_KEY,
    Link,
    add_links,
)
from tests.integration_tests.cache.fake_embeddings import (
    AngularTwoDimensionalEmbeddings,
    FakeEmbeddings,
)

TEST_KEYSPACE = "graph_test_keyspace"


class ParserEmbeddings(Embeddings):
    """Parse input texts: if they are json for a List[float], fine.
    Otherwise, return all zeros and call it a day.
    """

    def __init__(self, dimension: int) -> None:
        self.dimension = dimension

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        return [self.embed_query(txt) for txt in texts]

    def embed_query(self, text: str) -> list[float]:
        try:
            vals = json.loads(text)
        except json.JSONDecodeError:
            return [0.0] * self.dimension
        else:
            assert len(vals) == self.dimension
            return vals


@pytest.fixture
def embedding_d2() -> Embeddings:
    return ParserEmbeddings(dimension=2)


class EarthEmbeddings(Embeddings):
    def get_vector_near(self, value: float) -> List[float]:
        base_point = [value, (1 - value**2) ** 0.5]
        fluctuation = random.random() / 100.0
        return [base_point[0] + fluctuation, base_point[1] - fluctuation]

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        return [self.embed_query(txt) for txt in texts]

    def embed_query(self, text: str) -> list[float]:
        words = set(text.lower().split())
        if "earth" in words:
            vector = self.get_vector_near(0.9)
        elif {"planet", "world", "globe", "sphere"}.intersection(words):
            vector = self.get_vector_near(0.8)
        else:
            vector = self.get_vector_near(0.1)
        return vector


def _result_ids(docs: Iterable[Document]) -> List[Optional[str]]:
    return [doc.id for doc in docs]


@pytest.fixture
def graph_vector_store_docs() -> list[Document]:
    """
    This is a set of Documents to pre-populate a graph vector store,
    with entries placed in a certain way.

    Space of the entries (under Euclidean similarity):

                      A0    (*)
        ....        AL   AR       <....
        :              |              :
        :              |  ^           :
        v              |  .           v
                       |   :
       TR              |   :          BL
    T0   --------------x--------------   B0
       TL              |   :          BR
                       |   :
                       |  .
                       |  .
                       |
                    FL   FR
                      F0

    the query point is meant to be at (*).
    the A are bidirectionally linked with B
    the A are outgoing to T
    the A are incoming from F
    The links are like: L with L, 0 with 0 and R with R.
    """
    docs_a = [
        Document(id="AL", page_content="[-1, 9]", metadata={"label": "AL"}),
        Document(id="A0", page_content="[0, 10]", metadata={"label": "A0"}),
        Document(id="AR", page_content="[1, 9]", metadata={"label": "AR"}),
    ]
    docs_b = [
        Document(id="BL", page_content="[9, 1]", metadata={"label": "BL"}),
        Document(id="B0", page_content="[10, 0]", metadata={"label": "B0"}),
        Document(id="BR", page_content="[9, -1]", metadata={"label": "BR"}),
    ]
    docs_f = [
        Document(id="FL", page_content="[1, -9]", metadata={"label": "FL"}),
        Document(id="F0", page_content="[0, -10]", metadata={"label": "F0"}),
        Document(id="FR", page_content="[-1, -9]", metadata={"label": "FR"}),
    ]
    docs_t = [
        Document(id="TL", page_content="[-9, -1]", metadata={"label": "TL"}),
        Document(id="T0", page_content="[-10, 0]", metadata={"label": "T0"}),
        Document(id="TR", page_content="[-9, 1]", metadata={"label": "TR"}),
    ]
    for doc_a, suffix in zip(docs_a, ["l", "0", "r"]):
        add_links(doc_a, Link.bidir(kind="ab_example", tag=f"tag_{suffix}"))
        add_links(doc_a, Link.outgoing(kind="at_example", tag=f"tag_{suffix}"))
        add_links(doc_a, Link.incoming(kind="af_example", tag=f"tag_{suffix}"))
    for doc_b, suffix in zip(docs_b, ["l", "0", "r"]):
        add_links(doc_b, Link.bidir(kind="ab_example", tag=f"tag_{suffix}"))
    for doc_t, suffix in zip(docs_t, ["l", "0", "r"]):
        add_links(doc_t, Link.incoming(kind="at_example", tag=f"tag_{suffix}"))
    for doc_f, suffix in zip(docs_f, ["l", "0", "r"]):
        add_links(doc_f, Link.outgoing(kind="af_example", tag=f"tag_{suffix}"))
    return docs_a + docs_b + docs_f + docs_t


class CassandraSession:
    table_name: str
    session: Any

    def __init__(self, table_name: str, session: Any):
        self.table_name = table_name
        self.session = session


@contextmanager
def get_cassandra_session(
    table_name: str, drop: bool = True
) -> Generator[CassandraSession, None, None]:
    """Initialize the Cassandra cluster and session"""
    from cassandra.cluster import Cluster

    if "CASSANDRA_CONTACT_POINTS" in os.environ:
        contact_points = [
            cp.strip()
            for cp in os.environ["CASSANDRA_CONTACT_POINTS"].split(",")
            if cp.strip()
        ]
    else:
        contact_points = None

    cluster = Cluster(contact_points)
    session = cluster.connect()

    try:
        session.execute(
            (
                f"CREATE KEYSPACE IF NOT EXISTS {TEST_KEYSPACE}"
                " WITH replication = "
                "{'class': 'SimpleStrategy', 'replication_factor': 1}"
            )
        )
        if drop:
            session.execute(f"DROP TABLE IF EXISTS {TEST_KEYSPACE}.{table_name}")

        # Yield the session for usage
        yield CassandraSession(table_name=table_name, session=session)
    finally:
        # Ensure proper shutdown/cleanup of resources
        session.shutdown()
        cluster.shutdown()


@pytest.fixture(scope="function")
def graph_vector_store_angular(
    table_name: str = "graph_test_table",
) -> Generator[CassandraGraphVectorStore, None, None]:
    with get_cassandra_session(table_name=table_name) as session:
        yield CassandraGraphVectorStore(
            embedding=AngularTwoDimensionalEmbeddings(),
            session=session.session,
            keyspace=TEST_KEYSPACE,
            table_name=session.table_name,
        )


@pytest.fixture(scope="function")
def graph_vector_store_earth(
    table_name: str = "graph_test_table",
) -> Generator[CassandraGraphVectorStore, None, None]:
    with get_cassandra_session(table_name=table_name) as session:
        yield CassandraGraphVectorStore(
            embedding=EarthEmbeddings(),
            session=session.session,
            keyspace=TEST_KEYSPACE,
            table_name=session.table_name,
        )


@pytest.fixture(scope="function")
def graph_vector_store_fake(
    table_name: str = "graph_test_table",
) -> Generator[CassandraGraphVectorStore, None, None]:
    with get_cassandra_session(table_name=table_name) as session:
        yield CassandraGraphVectorStore(
            embedding=FakeEmbeddings(),
            session=session.session,
            keyspace=TEST_KEYSPACE,
            table_name=session.table_name,
        )


@pytest.fixture(scope="function")
def graph_vector_store_d2(
    embedding_d2: Embeddings,
    table_name: str = "graph_test_table",
) -> Generator[CassandraGraphVectorStore, None, None]:
    with get_cassandra_session(table_name=table_name) as session:
        yield CassandraGraphVectorStore(
            embedding=embedding_d2,
            session=session.session,
            keyspace=TEST_KEYSPACE,
            table_name=session.table_name,
        )


@pytest.fixture(scope="function")
def populated_graph_vector_store_d2(
    graph_vector_store_d2: CassandraGraphVectorStore,
    graph_vector_store_docs: list[Document],
) -> Generator[CassandraGraphVectorStore, None, None]:
    graph_vector_store_d2.add_documents(graph_vector_store_docs)
    yield graph_vector_store_d2


def test_mmr_traversal(graph_vector_store_angular: CassandraGraphVectorStore) -> None:
    """Test end to end construction and MMR search.

    The embedding function used here ensures `texts` become
    the following vectors on a circle (numbered v0 through v3):

@@ -128,140 +267,128 @@ def test_mmr_traversal() -> None:
    Both v2 and v3 are reachable via edges from v0, so once it is
    selected, those are both considered.
    """
    v0 = Node(
        id="v0",
        text="-0.124",
        links=[
            Link.outgoing(kind="explicit", tag="link"),
        ],
    )
    v1 = Node(
        id="v1",
        text="+0.127",
    )
    v2 = Node(
        id="v2",
        text="+0.25",
        links=[
            Link.incoming(kind="explicit", tag="link"),
        ],
    )
    v3 = Node(
        id="v3",
        text="+1.0",
        links=[
            Link.incoming(kind="explicit", tag="link"),
        ],
    )

    g_store = graph_vector_store_angular
    g_store.add_nodes([v0, v1, v2, v3])

    results = g_store.mmr_traversal_search("0.0", k=2, fetch_k=2)
    assert _result_ids(results) == ["v0", "v2"]

    # With max depth 0, no edges are traversed, so this doesn't reach v2 or v3.
    # So it ends up picking "v1" even though it's similar to "v0".
    results = g_store.mmr_traversal_search("0.0", k=2, fetch_k=2, depth=0)
    assert _result_ids(results) == ["v0", "v1"]

    # With max depth 0 but higher `fetch_k`, we encounter v2
    results = g_store.mmr_traversal_search("0.0", k=2, fetch_k=3, depth=0)
    assert _result_ids(results) == ["v0", "v2"]

    # v0 score is .46, v2 score is 0.16 so it won't be chosen.
    results = g_store.mmr_traversal_search("0.0", k=2, score_threshold=0.2)
    assert _result_ids(results) == ["v0"]

    # with k=4 we should get all of the documents.
    results = g_store.mmr_traversal_search("0.0", k=4)
    assert _result_ids(results) == ["v0", "v2", "v1", "v3"]


def test_write_retrieve_keywords(
    graph_vector_store_earth: CassandraGraphVectorStore,
) -> None:
    greetings = Node(
        id="greetings",
        text="Typical Greetings",
        links=[
            Link.incoming(kind="parent", tag="parent"),
        ],
    )
    node1 = Node(
        id="doc1",
        text="Hello World",
        links=[
            Link.outgoing(kind="parent", tag="parent"),
            Link.bidir(kind="kw", tag="greeting"),
            Link.bidir(kind="kw", tag="world"),
        ],
    )
    node2 = Node(
        id="doc2",
        text="Hello Earth",
        links=[
            Link.outgoing(kind="parent", tag="parent"),
            Link.bidir(kind="kw", tag="greeting"),
            Link.bidir(kind="kw", tag="earth"),
        ],
    )

    g_store = graph_vector_store_earth
    g_store.add_nodes(nodes=[greetings, node1, node2])

    # Doc2 is more similar, but World and Earth are similar enough that doc1 also
    # shows up.
    results: Iterable[Document] = g_store.similarity_search("Earth", k=2)
    assert _result_ids(results) == ["doc2", "doc1"]

    results = g_store.similarity_search("Earth", k=1)
    assert _result_ids(results) == ["doc2"]

    results = g_store.traversal_search("Earth", k=2, depth=0)
    assert _result_ids(results) == ["doc2", "doc1"]

    results = g_store.traversal_search("Earth", k=2, depth=1)
    assert _result_ids(results) == ["doc2", "doc1", "greetings"]

    # K=1 only pulls in doc2 (Hello Earth)
    results = g_store.traversal_search("Earth", k=1, depth=0)
    assert _result_ids(results) == ["doc2"]

    # K=1 only pulls in doc2 (Hello Earth). Depth=1 traverses to parent and via
    # keyword edge.
    results = g_store.traversal_search("Earth", k=1, depth=1)
    assert set(_result_ids(results)) == {"doc2", "doc1", "greetings"}


def test_metadata(graph_vector_store_fake: CassandraGraphVectorStore) -> None:
    doc_a = Node(
        id="a",
        text="A",
        metadata={"other": "some other field"},
        links=[
            Link.incoming(kind="hyperlink", tag="http://a"),
            Link.bidir(kind="other", tag="foo"),
        ],
    )

    g_store = graph_vector_store_fake
    g_store.add_nodes([doc_a])
    results = g_store.similarity_search("A")
    assert len(results) == 1
    assert results[0].id == "a"
    metadata = results[0].metadata
@@ -270,3 +397,274 @@ def test_metadata() -> None:
        Link.incoming(kind="hyperlink", tag="http://a"),
        Link.bidir(kind="other", tag="foo"),
    }


class TestCassandraGraphVectorStore:
    def test_gvs_similarity_search_sync(
        self,
        populated_graph_vector_store_d2: CassandraGraphVectorStore,
    ) -> None:
        """Simple (non-graph) similarity search on a graph vector store."""
        g_store = populated_graph_vector_store_d2
        ss_response = g_store.similarity_search(query="[2, 10]", k=2)
        ss_labels = [doc.metadata["label"] for doc in ss_response]
        assert ss_labels == ["AR", "A0"]
        ss_by_v_response = g_store.similarity_search_by_vector(embedding=[2, 10], k=2)
        ss_by_v_labels = [doc.metadata["label"] for doc in ss_by_v_response]
        assert ss_by_v_labels == ["AR", "A0"]

    async def test_gvs_similarity_search_async(
        self,
        populated_graph_vector_store_d2: CassandraGraphVectorStore,
    ) -> None:
        """Simple (non-graph) similarity search on a graph vector store."""
        g_store = populated_graph_vector_store_d2
        ss_response = await g_store.asimilarity_search(query="[2, 10]", k=2)
        ss_labels = [doc.metadata["label"] for doc in ss_response]
        assert ss_labels == ["AR", "A0"]
        ss_by_v_response = await g_store.asimilarity_search_by_vector(
            embedding=[2, 10], k=2
        )
        ss_by_v_labels = [doc.metadata["label"] for doc in ss_by_v_response]
        assert ss_by_v_labels == ["AR", "A0"]

    def test_gvs_traversal_search_sync(
        self,
        populated_graph_vector_store_d2: CassandraGraphVectorStore,
    ) -> None:
        """Graph traversal search on a graph vector store."""
        g_store = populated_graph_vector_store_d2
        ts_response = g_store.traversal_search(query="[2, 10]", k=2, depth=2)
        # this is a set, as some of the internals of trav.search are set-driven
        # so ordering is not deterministic:
        ts_labels = {doc.metadata["label"] for doc in ts_response}
        assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"}

    async def test_gvs_traversal_search_async(
        self,
        populated_graph_vector_store_d2: CassandraGraphVectorStore,
    ) -> None:
        """Graph traversal search on a graph vector store."""
        g_store = populated_graph_vector_store_d2
        ts_labels = set()
        async for doc in g_store.atraversal_search(query="[2, 10]", k=2, depth=2):
            ts_labels.add(doc.metadata["label"])
        # this is a set, as some of the internals of trav.search are set-driven
        # so ordering is not deterministic:
        assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"}

    def test_gvs_mmr_traversal_search_sync(
        self,
        populated_graph_vector_store_d2: CassandraGraphVectorStore,
    ) -> None:
        """MMR Graph traversal search on a graph vector store."""
        g_store = populated_graph_vector_store_d2
        mt_response = g_store.mmr_traversal_search(
            query="[2, 10]",
            k=2,
            depth=2,
            fetch_k=1,
            adjacent_k=2,
            lambda_mult=0.1,
        )
        # TODO: can this rightfully be a list (or must it be a set)?
        mt_labels = {doc.metadata["label"] for doc in mt_response}
        assert mt_labels == {"AR", "BR"}

    async def test_gvs_mmr_traversal_search_async(
        self,
        populated_graph_vector_store_d2: CassandraGraphVectorStore,
    ) -> None:
        """MMR Graph traversal search on a graph vector store."""
        g_store = populated_graph_vector_store_d2
        mt_labels = set()
        async for doc in g_store.ammr_traversal_search(
            query="[2, 10]",
            k=2,
            depth=2,
            fetch_k=1,
            adjacent_k=2,
            lambda_mult=0.1,
        ):
            mt_labels.add(doc.metadata["label"])
        # TODO: can this rightfully be a list (or must it be a set)?
        assert mt_labels == {"AR", "BR"}

    def test_gvs_metadata_search_sync(
        self,
        populated_graph_vector_store_d2: CassandraGraphVectorStore,
    ) -> None:
        """Metadata search on a graph vector store."""
        g_store = populated_graph_vector_store_d2
        mt_response = g_store.metadata_search(
            filter={"label": "T0"},
            n=2,
        )
        doc: Document = next(iter(mt_response))
        assert doc.page_content == "[-10, 0]"
        links = doc.metadata["links"]
        assert len(links) == 1
        link: Link = links.pop()
        assert isinstance(link, Link)
        assert link.direction == "in"
        assert link.kind == "at_example"
        assert link.tag == "tag_0"

    async def test_gvs_metadata_search_async(
        self,
        populated_graph_vector_store_d2: CassandraGraphVectorStore,
    ) -> None:
        """Metadata search on a graph vector store."""
        g_store = populated_graph_vector_store_d2
        mt_response = await g_store.ametadata_search(
            filter={"label": "T0"},
            n=2,
        )
        doc: Document = next(iter(mt_response))
        assert doc.page_content == "[-10, 0]"
        links: set[Link] = doc.metadata["links"]
        assert len(links) == 1
        link: Link = links.pop()
        assert isinstance(link, Link)
        assert link.direction == "in"
        assert link.kind == "at_example"
        assert link.tag == "tag_0"

    def test_gvs_get_by_document_id_sync(
        self,
        populated_graph_vector_store_d2: CassandraGraphVectorStore,
    ) -> None:
        """Get by document_id on a graph vector store."""
        g_store = populated_graph_vector_store_d2
        doc = g_store.get_by_document_id(document_id="FL")
        assert doc is not None
        assert doc.page_content == "[1, -9]"
        links = doc.metadata["links"]
        assert len(links) == 1
        link: Link = links.pop()
        assert isinstance(link, Link)
        assert link.direction == "out"
        assert link.kind == "af_example"
        assert link.tag == "tag_l"

        invalid_doc = g_store.get_by_document_id(document_id="invalid")
        assert invalid_doc is None

    async def test_gvs_get_by_document_id_async(
        self,
        populated_graph_vector_store_d2: CassandraGraphVectorStore,
    ) -> None:
        """Get by document_id on a graph vector store."""
        g_store = populated_graph_vector_store_d2
        doc = await g_store.aget_by_document_id(document_id="FL")
        assert doc is not None
        assert doc.page_content == "[1, -9]"
        links = doc.metadata["links"]
        assert len(links) == 1
        link: Link = links.pop()
        assert isinstance(link, Link)
        assert link.direction == "out"
        assert link.kind == "af_example"
        assert link.tag == "tag_l"

        invalid_doc = await g_store.aget_by_document_id(document_id="invalid")
        assert invalid_doc is None

    def test_gvs_from_texts(
        self,
        graph_vector_store_d2: CassandraGraphVectorStore,
    ) -> None:
        g_store = graph_vector_store_d2
        g_store.add_texts(
            texts=["[1, 2]"],
            metadatas=[{"md": 1}],
            ids=["x_id"],
        )

        hits = g_store.similarity_search("[2, 1]", k=2)
        assert len(hits) == 1
        assert hits[0].page_content == "[1, 2]"
        assert hits[0].id == "x_id"
        # there may be more re:graph structure.
        assert hits[0].metadata["md"] == "1.0"

    def test_gvs_from_documents_containing_ids(
        self,
        graph_vector_store_d2: CassandraGraphVectorStore,
    ) -> None:
        the_document = Document(
            page_content="[1, 2]",
            metadata={"md": 1},
            id="x_id",
        )
        g_store = graph_vector_store_d2
        g_store.add_documents([the_document])
        hits = g_store.similarity_search("[2, 1]", k=2)
        assert len(hits) == 1
        assert hits[0].page_content == "[1, 2]"
        assert hits[0].id == "x_id"
        # there may be more re:graph structure.
        assert hits[0].metadata["md"] == "1.0"

    def test_gvs_add_nodes_sync(
        self,
        *,
        graph_vector_store_d2: CassandraGraphVectorStore,
    ) -> None:
        links0 = [
            Link(kind="kA", direction="out", tag="tA"),
            Link(kind="kB", direction="bidir", tag="tB"),
        ]
        links1 = [
            Link(kind="kC", direction="in", tag="tC"),
        ]
        nodes = [
            Node(id="id0", text="[1, 0]", metadata={"m": 0}, links=links0),
            Node(text="[-1, 0]", metadata={"m": 1}, links=links1),
        ]
        graph_vector_store_d2.add_nodes(nodes)
        hits = graph_vector_store_d2.similarity_search_by_vector([0.9, 0.1])
        assert len(hits) == 2
        assert hits[0].id == "id0"
        assert hits[0].page_content == "[1, 0]"
        md0 = hits[0].metadata
        assert md0["m"] == "0.0"
        assert any(isinstance(v, set) for k, v in md0.items() if k != "m")
        assert hits[1].id != "id0"
        assert hits[1].page_content == "[-1, 0]"
        md1 = hits[1].metadata
        assert md1["m"] == "1.0"
        assert any(isinstance(v, set) for k, v in md1.items() if k != "m")

    async def test_gvs_add_nodes_async(
        self,
        *,
        graph_vector_store_d2: CassandraGraphVectorStore,
    ) -> None:
        links0 = [
            Link(kind="kA", direction="out", tag="tA"),
            Link(kind="kB", direction="bidir", tag="tB"),
        ]
        links1 = [
            Link(kind="kC", direction="in", tag="tC"),
        ]
        nodes = [
            Node(id="id0", text="[1, 0]", metadata={"m": 0}, links=links0),
            Node(text="[-1, 0]", metadata={"m": 1}, links=links1),
        ]
        async for _ in graph_vector_store_d2.aadd_nodes(nodes):
            pass

        hits = await graph_vector_store_d2.asimilarity_search_by_vector([0.9, 0.1])
        assert len(hits) == 2
        assert hits[0].id == "id0"
        assert hits[0].page_content == "[1, 0]"
        md0 = hits[0].metadata
        assert md0["m"] == "0.0"
        assert any(isinstance(v, set) for k, v in md0.items() if k != "m")
        assert hits[1].id != "id0"
        assert hits[1].page_content == "[-1, 0]"
        md1 = hits[1].metadata
        assert md1["m"] == "1.0"
        assert any(isinstance(v, set) for k, v in md1.items() if k != "m")


@@ -0,0 +1,269 @@
"""Test of Upgrading to Apache Cassandra graph vector store class:
`CassandraGraphVectorStore` from an existing table used
by the Cassandra vector store class: `Cassandra`
"""

from __future__ import annotations

import json
import os
from contextlib import contextmanager
from typing import Any, Generator, Iterable, Optional, Tuple, Union

import pytest
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings

from langchain_community.graph_vectorstores import CassandraGraphVectorStore
from langchain_community.utilities.cassandra import SetupMode
from langchain_community.vectorstores import Cassandra

TEST_KEYSPACE = "graph_test_keyspace"

TABLE_NAME_ALLOW_INDEXING = "allow_graph_table"
TABLE_NAME_DEFAULT = "default_graph_table"
TABLE_NAME_DENY_INDEXING = "deny_graph_table"


class ParserEmbeddings(Embeddings):
    """Parse input texts: if they are json for a List[float], fine.
    Otherwise, return all zeros and call it a day.
    """

    def __init__(self, dimension: int) -> None:
        self.dimension = dimension

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        return [self.embed_query(txt) for txt in texts]

    async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
        return self.embed_documents(texts)

    def embed_query(self, text: str) -> list[float]:
        try:
            vals = json.loads(text)
        except json.JSONDecodeError:
            return [0.0] * self.dimension
        else:
            assert len(vals) == self.dimension
            return vals

    async def aembed_query(self, text: str) -> list[float]:
        return self.embed_query(text)


@pytest.fixture
def embedding_d2() -> Embeddings:
    return ParserEmbeddings(dimension=2)


class CassandraSession:
    table_name: str
    session: Any

    def __init__(self, table_name: str, session: Any):
        self.table_name = table_name
        self.session = session


@contextmanager
def get_cassandra_session(
    table_name: str, drop: bool = True
) -> Generator[CassandraSession, None, None]:
    """Initialize the Cassandra cluster and session"""
    from cassandra.cluster import Cluster

    if "CASSANDRA_CONTACT_POINTS" in os.environ:
        contact_points = [
            cp.strip()
            for cp in os.environ["CASSANDRA_CONTACT_POINTS"].split(",")
            if cp.strip()
        ]
    else:
        contact_points = None

    cluster = Cluster(contact_points)
    session = cluster.connect()

    try:
        session.execute(
            (
                f"CREATE KEYSPACE IF NOT EXISTS {TEST_KEYSPACE}"
                " WITH replication = "
                "{'class': 'SimpleStrategy', 'replication_factor': 1}"
            )
        )
        if drop:
            session.execute(f"DROP TABLE IF EXISTS {TEST_KEYSPACE}.{table_name}")

        # Yield the session for usage
        yield CassandraSession(table_name=table_name, session=session)
    finally:
        # Ensure proper shutdown/cleanup of resources
        session.shutdown()
        cluster.shutdown()


@contextmanager
def vector_store(
    embedding: Embeddings,
    table_name: str,
    setup_mode: SetupMode,
    metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
    drop: bool = True,
) -> Generator[Cassandra, None, None]:
    with get_cassandra_session(table_name=table_name, drop=drop) as session:
        yield Cassandra(
            table_name=session.table_name,
            keyspace=TEST_KEYSPACE,
            session=session.session,
            embedding=embedding,
            setup_mode=setup_mode,
            metadata_indexing=metadata_indexing,
        )


@contextmanager
def graph_vector_store(
    embedding: Embeddings,
    table_name: str,
    setup_mode: SetupMode,
    metadata_deny_list: Optional[list[str]] = None,
    drop: bool = True,
) -> Generator[CassandraGraphVectorStore, None, None]:
    with get_cassandra_session(table_name=table_name, drop=drop) as session:
        yield CassandraGraphVectorStore(
            table_name=session.table_name,
            keyspace=TEST_KEYSPACE,
            session=session.session,
            embedding=embedding,
            setup_mode=setup_mode,
            metadata_deny_list=metadata_deny_list,
        )


def _vs_indexing_policy(table_name: str) -> Union[Tuple[str, Iterable[str]], str]:
    if table_name == TABLE_NAME_ALLOW_INDEXING:
        return ("allowlist", ["test"])
    if table_name == TABLE_NAME_DEFAULT:
        return "all"
    if table_name == TABLE_NAME_DENY_INDEXING:
        return ("denylist", ["test"])
    msg = f"Unknown table_name: {table_name} in _vs_indexing_policy()"
    raise ValueError(msg)


class TestUpgradeToGraphVectorStore:
    @pytest.mark.parametrize(
        ("table_name", "gvs_setup_mode", "gvs_metadata_deny_list"),
        [
            (TABLE_NAME_DEFAULT, SetupMode.SYNC, None),
            (TABLE_NAME_DENY_INDEXING, SetupMode.SYNC, ["test"]),
            (TABLE_NAME_DEFAULT, SetupMode.OFF, None),
            (TABLE_NAME_DENY_INDEXING, SetupMode.OFF, ["test"]),
            # for this one, even though the passed policy doesn't
            # match the policy used to create the collection,
            # there is no error since the SetupMode is OFF and
            # no attempt is made to re-create the collection.
            (TABLE_NAME_DENY_INDEXING, SetupMode.OFF, None),
        ],
        ids=[
            "default_upgrade_no_policy_sync",
            "deny_list_upgrade_same_policy_sync",
            "default_upgrade_no_policy_off",
            "deny_list_upgrade_same_policy_off",
            "deny_list_upgrade_change_policy_off",
        ],
    )
    def test_upgrade_to_gvs_success_sync(
        self,
        *,
        embedding_d2: Embeddings,
        gvs_setup_mode: SetupMode,
        table_name: str,
        gvs_metadata_deny_list: list[str],
    ) -> None:
        doc_id = "AL"
        doc_al = Document(id=doc_id, page_content="[-1, 9]", metadata={"label": "AL"})

        # Create vector store using SetupMode.SYNC
        with vector_store(
            embedding=embedding_d2,
            table_name=table_name,
            setup_mode=SetupMode.SYNC,
            metadata_indexing=_vs_indexing_policy(table_name=table_name),
            drop=True,
        ) as v_store:
            # load a document to the vector store
            v_store.add_documents([doc_al])

            # get the document from the vector store
            v_doc = v_store.get_by_document_id(document_id=doc_id)
            assert v_doc is not None
            assert v_doc.page_content == doc_al.page_content

        # Create a GRAPH Vector Store using the existing collection from above
        # with setup_mode=gvs_setup_mode and indexing_policy=gvs_indexing_policy
        with graph_vector_store(
            embedding=embedding_d2,
            table_name=table_name,
            setup_mode=gvs_setup_mode,
            metadata_deny_list=gvs_metadata_deny_list,
            drop=False,
        ) as gv_store:
            # get the document from the GRAPH vector store
            gv_doc = gv_store.get_by_document_id(document_id=doc_id)
            assert gv_doc is not None
            assert gv_doc.page_content == doc_al.page_content

    @pytest.mark.parametrize(
        ("table_name", "gvs_setup_mode", "gvs_metadata_deny_list"),
        [
            (TABLE_NAME_DEFAULT, SetupMode.ASYNC, None),
            (TABLE_NAME_DENY_INDEXING, SetupMode.ASYNC, ["test"]),
        ],
        ids=[
            "default_upgrade_no_policy_async",
            "deny_list_upgrade_same_policy_async",
        ],
    )
    async def test_upgrade_to_gvs_success_async(
        self,
        *,
        embedding_d2: Embeddings,
        gvs_setup_mode: SetupMode,
        table_name: str,
        gvs_metadata_deny_list: list[str],
    ) -> None:
        doc_id = "AL"
        doc_al = Document(id=doc_id, page_content="[-1, 9]", metadata={"label": "AL"})

        # Create vector store using SetupMode.ASYNC
        with vector_store(
            embedding=embedding_d2,
            table_name=table_name,
            setup_mode=SetupMode.ASYNC,
            metadata_indexing=_vs_indexing_policy(table_name=table_name),
            drop=True,
        ) as v_store:
            # load a document to the vector store
            await v_store.aadd_documents([doc_al])

            # get the document from the vector store
            v_doc = await v_store.aget_by_document_id(document_id=doc_id)
            assert v_doc is not None
            assert v_doc.page_content == doc_al.page_content

        # Create a GRAPH Vector Store using the existing collection from above
        # with setup_mode=gvs_setup_mode and indexing_policy=gvs_indexing_policy
        with graph_vector_store(
            embedding=embedding_d2,
            table_name=table_name,
            setup_mode=gvs_setup_mode,
            metadata_deny_list=gvs_metadata_deny_list,
            drop=False,
        ) as gv_store:
            # get the document from the GRAPH vector store
            gv_doc = await gv_store.aget_by_document_id(document_id=doc_id)
            assert gv_doc is not None
            assert gv_doc.page_content == doc_al.page_content