mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-01 19:12:42 +00:00
community: Cassandra Vector Store: modernize implementation (#27253)
**Description:**

This PR updates `CassandraGraphVectorStore` to be based on `CassandraVectorStore` instead of using a custom CQL implementation. This allows users of a `CassandraVectorStore` to upgrade to a `GraphVectorStore` without having to change their database schema or re-embed documents.

This PR also updates the documentation of the `GraphVectorStore` base class and contains native async implementations of the standard graph methods `traversal_search` and `mmr_traversal_search` in `CassandraVectorStore`.

**Issue:** No issue number.

**Dependencies:** https://github.com/langchain-ai/langchain/pull/27078 (already merged)

**Lint and test:**

- Lint and tests all pass, including the existing `CassandraGraphVectorStore` tests.
- Numerous additional tests were added, based on the tests in `langchain-astradb`, which cover many more scenarios than the existing tests for `Cassandra` and `CassandraGraphVectorStore`.

**BREAKING CHANGE**

This is a breaking change for existing users of `CassandraGraphVectorStore`: they will need to wipe their database table and restart. However:

- The interfaces have not changed, only the underlying storage mechanism.
- Anyone using `langchain_community.vectorstores.Cassandra` can instead use `langchain_community.graph_vectorstores.CassandraGraphVectorStore` and gain graph capabilities without having to re-embed their existing documents. This is the primary goal of this PR.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
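As a quick illustration of that upgrade path, here is a minimal sketch. It mirrors the constructor arguments exercised by the tests in this commit; the cluster setup, the `my_keyspace`/`my_table` names, and the choice of `OpenAIEmbeddings` are illustrative assumptions, and the table is assumed to already exist and hold embedded documents:

```python
from cassandra.cluster import Cluster
from langchain_openai import OpenAIEmbeddings

from langchain_community.graph_vectorstores import CassandraGraphVectorStore
from langchain_community.vectorstores import Cassandra

session = Cluster().connect()
# Must be the same embedding model that originally populated the table.
embeddings = OpenAIEmbeddings()

# Before: a plain vector store over an existing, already-embedded table.
v_store = Cassandra(
    embedding=embeddings,
    session=session,
    keyspace="my_keyspace",  # illustrative name
    table_name="my_table",  # illustrative name
)

# After: the same table opened as a graph vector store; no schema change
# and no re-embedding, since both classes now share one storage layout.
g_store = CassandraGraphVectorStore(
    embedding=embeddings,
    session=session,
    keyspace="my_keyspace",
    table_name="my_table",
)

# Graph-aware retrieval (sync shown; native async variants also exist).
docs = list(g_store.traversal_search("my query", k=4, depth=1))
```

Tables created by the old CQL-based `CassandraGraphVectorStore`, by contrast, cannot be opened this way; as noted above, they must be wiped and re-created.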
@@ -1,116 +1,255 @@
-import math
-import os
-from typing import Iterable, List, Optional, Type
+"""Test of Apache Cassandra graph vector store class `CassandraGraphVectorStore`"""
+
+import json
+import os
+import random
+from contextlib import contextmanager
+from typing import Any, Generator, Iterable, List, Optional
+
+import pytest
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
 
 from langchain_community.graph_vectorstores import CassandraGraphVectorStore
-from langchain_community.graph_vectorstores.links import METADATA_LINKS_KEY, Link
+from langchain_community.graph_vectorstores.base import Node
+from langchain_community.graph_vectorstores.links import (
+    METADATA_LINKS_KEY,
+    Link,
+    add_links,
+)
+from tests.integration_tests.cache.fake_embeddings import (
+    AngularTwoDimensionalEmbeddings,
+    FakeEmbeddings,
+)
 
-CASSANDRA_DEFAULT_KEYSPACE = "graph_test_keyspace"
+TEST_KEYSPACE = "graph_test_keyspace"
 
 
-def _get_graph_store(
-    embedding_class: Type[Embeddings], documents: Iterable[Document] = ()
-) -> CassandraGraphVectorStore:
-    import cassio
-    from cassandra.cluster import Cluster
-    from cassio.config import check_resolve_session, resolve_keyspace
-
-    node_table = "graph_test_node_table"
-    edge_table = "graph_test_edge_table"
-
-    if any(
-        env_var in os.environ
-        for env_var in [
-            "CASSANDRA_CONTACT_POINTS",
-            "ASTRA_DB_APPLICATION_TOKEN",
-            "ASTRA_DB_INIT_STRING",
-        ]
-    ):
-        cassio.init(auto=True)
-        session = check_resolve_session()
-    else:
-        cluster = Cluster()
-        session = cluster.connect()
-    keyspace = resolve_keyspace() or CASSANDRA_DEFAULT_KEYSPACE
-    cassio.init(session=session, keyspace=keyspace)
-    # ensure keyspace exists
-    session.execute(
-        (
-            f"CREATE KEYSPACE IF NOT EXISTS {keyspace} "
-            f"WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}"
-        )
-    )
-    session.execute(f"DROP TABLE IF EXISTS {keyspace}.{node_table}")
-    session.execute(f"DROP TABLE IF EXISTS {keyspace}.{edge_table}")
-    store = CassandraGraphVectorStore.from_documents(
-        documents,
-        embedding=embedding_class(),
-        session=session,
-        keyspace=keyspace,
-        node_table=node_table,
-        targets_table=edge_table,
-    )
-    return store
-
-
-class FakeEmbeddings(Embeddings):
-    """Fake embeddings functionality for testing."""
-
-    def embed_documents(self, texts: List[str]) -> List[List[float]]:
-        """Return simple embeddings.
-        Embeddings encode each text as its index."""
-        return [[float(1.0)] * 9 + [float(i)] for i in range(len(texts))]
-
-    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
-        return self.embed_documents(texts)
-
-    def embed_query(self, text: str) -> List[float]:
-        """Return constant query embeddings.
-        Embeddings are identical to embed_documents(texts)[0].
-        Distance to each text will be that text's index,
-        as it was passed to embed_documents."""
-        return [float(1.0)] * 9 + [float(0.0)]
-
-    async def aembed_query(self, text: str) -> List[float]:
-        return self.embed_query(text)
-
-
-class AngularTwoDimensionalEmbeddings(Embeddings):
-    """
-    From angles (as strings in units of pi) to unit embedding vectors on a circle.
+class ParserEmbeddings(Embeddings):
+    """Parse input texts: if they are json for a List[float], fine.
+    Otherwise, return all zeros and call it a day.
     """
 
-    def embed_documents(self, texts: List[str]) -> List[List[float]]:
-        """
-        Make a list of texts into a list of embedding vectors.
-        """
-        return [self.embed_query(text) for text in texts]
+    def __init__(self, dimension: int) -> None:
+        self.dimension = dimension
 
-    def embed_query(self, text: str) -> List[float]:
-        """
-        Convert input text to a 'vector' (list of floats).
-        If the text is a number, use it as the angle for the
-        unit vector in units of pi.
-        Any other input text becomes the singular result [0, 0] !
-        """
+    def embed_documents(self, texts: list[str]) -> list[list[float]]:
+        return [self.embed_query(txt) for txt in texts]
+
+    def embed_query(self, text: str) -> list[float]:
         try:
-            angle = float(text)
-            return [math.cos(angle * math.pi), math.sin(angle * math.pi)]
-        except ValueError:
-            # Assume: just test string, no attention is paid to values.
-            return [0.0, 0.0]
+            vals = json.loads(text)
+        except json.JSONDecodeError:
+            return [0.0] * self.dimension
+        else:
+            assert len(vals) == self.dimension
+            return vals
+
+
+@pytest.fixture
+def embedding_d2() -> Embeddings:
+    return ParserEmbeddings(dimension=2)
+
+
+class EarthEmbeddings(Embeddings):
+    def get_vector_near(self, value: float) -> List[float]:
+        base_point = [value, (1 - value**2) ** 0.5]
+        fluctuation = random.random() / 100.0
+        return [base_point[0] + fluctuation, base_point[1] - fluctuation]
+
+    def embed_documents(self, texts: list[str]) -> list[list[float]]:
+        return [self.embed_query(txt) for txt in texts]
+
+    def embed_query(self, text: str) -> list[float]:
+        words = set(text.lower().split())
+        if "earth" in words:
+            vector = self.get_vector_near(0.9)
+        elif {"planet", "world", "globe", "sphere"}.intersection(words):
+            vector = self.get_vector_near(0.8)
+        else:
+            vector = self.get_vector_near(0.1)
+        return vector
 
 
 def _result_ids(docs: Iterable[Document]) -> List[Optional[str]]:
     return [doc.id for doc in docs]
 
 
-def test_mmr_traversal() -> None:
+@pytest.fixture
+def graph_vector_store_docs() -> list[Document]:
     """
-    Test end to end construction and MMR search.
+    This is a set of Documents to pre-populate a graph vector store,
+    with entries placed in a certain way.
+
+    Space of the entries (under Euclidean similarity):
+
+                      A0    (*)
+        ....        AL   AR       <....
+        :              |              :
+        :              |  ^           :
+        v              |  .           v
+                       |   :
+       TR              |   :          BL
+    T0   --------------x--------------   B0
+       TL              |   :          BR
+                       |   :
+                       |  .
+                       | .
+                       |
+                    FL   FR
+                      F0
+
+    the query point is meant to be at (*).
+    the A are bidirectionally with B
+    the A are outgoing to T
+    the A are incoming from F
+    The links are like: L with L, 0 with 0 and R with R.
+    """
+    docs_a = [
+        Document(id="AL", page_content="[-1, 9]", metadata={"label": "AL"}),
+        Document(id="A0", page_content="[0, 10]", metadata={"label": "A0"}),
+        Document(id="AR", page_content="[1, 9]", metadata={"label": "AR"}),
+    ]
+    docs_b = [
+        Document(id="BL", page_content="[9, 1]", metadata={"label": "BL"}),
+        Document(id="B0", page_content="[10, 0]", metadata={"label": "B0"}),
+        Document(id="BR", page_content="[9, -1]", metadata={"label": "BR"}),
+    ]
+    docs_f = [
+        Document(id="FL", page_content="[1, -9]", metadata={"label": "FL"}),
+        Document(id="F0", page_content="[0, -10]", metadata={"label": "F0"}),
+        Document(id="FR", page_content="[-1, -9]", metadata={"label": "FR"}),
+    ]
+    docs_t = [
+        Document(id="TL", page_content="[-9, -1]", metadata={"label": "TL"}),
+        Document(id="T0", page_content="[-10, 0]", metadata={"label": "T0"}),
+        Document(id="TR", page_content="[-9, 1]", metadata={"label": "TR"}),
+    ]
+    for doc_a, suffix in zip(docs_a, ["l", "0", "r"]):
+        add_links(doc_a, Link.bidir(kind="ab_example", tag=f"tag_{suffix}"))
+        add_links(doc_a, Link.outgoing(kind="at_example", tag=f"tag_{suffix}"))
+        add_links(doc_a, Link.incoming(kind="af_example", tag=f"tag_{suffix}"))
+    for doc_b, suffix in zip(docs_b, ["l", "0", "r"]):
+        add_links(doc_b, Link.bidir(kind="ab_example", tag=f"tag_{suffix}"))
+    for doc_t, suffix in zip(docs_t, ["l", "0", "r"]):
+        add_links(doc_t, Link.incoming(kind="at_example", tag=f"tag_{suffix}"))
+    for doc_f, suffix in zip(docs_f, ["l", "0", "r"]):
+        add_links(doc_f, Link.outgoing(kind="af_example", tag=f"tag_{suffix}"))
+    return docs_a + docs_b + docs_f + docs_t
+
+
+class CassandraSession:
+    table_name: str
+    session: Any
+
+    def __init__(self, table_name: str, session: Any):
+        self.table_name = table_name
+        self.session = session
+
+
+@contextmanager
+def get_cassandra_session(
+    table_name: str, drop: bool = True
+) -> Generator[CassandraSession, None, None]:
+    """Initialize the Cassandra cluster and session"""
+    from cassandra.cluster import Cluster
+
+    if "CASSANDRA_CONTACT_POINTS" in os.environ:
+        contact_points = [
+            cp.strip()
+            for cp in os.environ["CASSANDRA_CONTACT_POINTS"].split(",")
+            if cp.strip()
+        ]
+    else:
+        contact_points = None
+
+    cluster = Cluster(contact_points)
+    session = cluster.connect()
+
+    try:
+        session.execute(
+            (
+                f"CREATE KEYSPACE IF NOT EXISTS {TEST_KEYSPACE}"
+                " WITH replication = "
+                "{'class': 'SimpleStrategy', 'replication_factor': 1}"
+            )
+        )
+        if drop:
+            session.execute(f"DROP TABLE IF EXISTS {TEST_KEYSPACE}.{table_name}")
+
+        # Yield the session for usage
+        yield CassandraSession(table_name=table_name, session=session)
+    finally:
+        # Ensure proper shutdown/cleanup of resources
+        session.shutdown()
+        cluster.shutdown()
+
+
+@pytest.fixture(scope="function")
+def graph_vector_store_angular(
+    table_name: str = "graph_test_table",
+) -> Generator[CassandraGraphVectorStore, None, None]:
+    with get_cassandra_session(table_name=table_name) as session:
+        yield CassandraGraphVectorStore(
+            embedding=AngularTwoDimensionalEmbeddings(),
+            session=session.session,
+            keyspace=TEST_KEYSPACE,
+            table_name=session.table_name,
+        )
+
+
+@pytest.fixture(scope="function")
+def graph_vector_store_earth(
+    table_name: str = "graph_test_table",
+) -> Generator[CassandraGraphVectorStore, None, None]:
+    with get_cassandra_session(table_name=table_name) as session:
+        yield CassandraGraphVectorStore(
+            embedding=EarthEmbeddings(),
+            session=session.session,
+            keyspace=TEST_KEYSPACE,
+            table_name=session.table_name,
+        )
+
+
+@pytest.fixture(scope="function")
+def graph_vector_store_fake(
+    table_name: str = "graph_test_table",
+) -> Generator[CassandraGraphVectorStore, None, None]:
+    with get_cassandra_session(table_name=table_name) as session:
+        yield CassandraGraphVectorStore(
+            embedding=FakeEmbeddings(),
+            session=session.session,
+            keyspace=TEST_KEYSPACE,
+            table_name=session.table_name,
+        )
+
+
+@pytest.fixture(scope="function")
+def graph_vector_store_d2(
+    embedding_d2: Embeddings,
+    table_name: str = "graph_test_table",
+) -> Generator[CassandraGraphVectorStore, None, None]:
+    with get_cassandra_session(table_name=table_name) as session:
+        yield CassandraGraphVectorStore(
+            embedding=embedding_d2,
+            session=session.session,
+            keyspace=TEST_KEYSPACE,
+            table_name=session.table_name,
+        )
+
+
+@pytest.fixture(scope="function")
+def populated_graph_vector_store_d2(
+    graph_vector_store_d2: CassandraGraphVectorStore,
+    graph_vector_store_docs: list[Document],
+) -> Generator[CassandraGraphVectorStore, None, None]:
+    graph_vector_store_d2.add_documents(graph_vector_store_docs)
+    yield graph_vector_store_d2
+
+
+def test_mmr_traversal(graph_vector_store_angular: CassandraGraphVectorStore) -> None:
+    """ Test end to end construction and MMR search.
     The embedding function used here ensures `texts` become
     the following vectors on a circle (numbered v0 through v3):
@@ -128,140 +267,128 @@ def test_mmr_traversal() -> None:
     Both v2 and v3 are reachable via edges from v0, so once it is
     selected, those are both considered.
     """
-    store = _get_graph_store(AngularTwoDimensionalEmbeddings)
-
-    v0 = Document(
+    v0 = Node(
         id="v0",
-        page_content="-0.124",
-        metadata={
-            METADATA_LINKS_KEY: [
-                Link.outgoing(kind="explicit", tag="link"),
-            ],
-        },
+        text="-0.124",
+        links=[
+            Link.outgoing(kind="explicit", tag="link"),
+        ],
     )
-    v1 = Document(
+    v1 = Node(
         id="v1",
-        page_content="+0.127",
+        text="+0.127",
     )
-    v2 = Document(
+    v2 = Node(
         id="v2",
-        page_content="+0.25",
-        metadata={
-            METADATA_LINKS_KEY: [
-                Link.incoming(kind="explicit", tag="link"),
-            ],
-        },
+        text="+0.25",
+        links=[
+            Link.incoming(kind="explicit", tag="link"),
+        ],
     )
-    v3 = Document(
+    v3 = Node(
         id="v3",
-        page_content="+1.0",
-        metadata={
-            METADATA_LINKS_KEY: [
-                Link.incoming(kind="explicit", tag="link"),
-            ],
-        },
+        text="+1.0",
+        links=[
+            Link.incoming(kind="explicit", tag="link"),
+        ],
     )
-    store.add_documents([v0, v1, v2, v3])
-
-    results = store.mmr_traversal_search("0.0", k=2, fetch_k=2)
+    g_store = graph_vector_store_angular
+    g_store.add_nodes([v0, v1, v2, v3])
+
+    results = g_store.mmr_traversal_search("0.0", k=2, fetch_k=2)
     assert _result_ids(results) == ["v0", "v2"]
 
     # With max depth 0, no edges are traversed, so this doesn't reach v2 or v3.
     # So it ends up picking "v1" even though it's similar to "v0".
-    results = store.mmr_traversal_search("0.0", k=2, fetch_k=2, depth=0)
+    results = g_store.mmr_traversal_search("0.0", k=2, fetch_k=2, depth=0)
     assert _result_ids(results) == ["v0", "v1"]
 
     # With max depth 0 but higher `fetch_k`, we encounter v2
-    results = store.mmr_traversal_search("0.0", k=2, fetch_k=3, depth=0)
+    results = g_store.mmr_traversal_search("0.0", k=2, fetch_k=3, depth=0)
     assert _result_ids(results) == ["v0", "v2"]
 
     # v0 score is .46, v2 score is 0.16 so it won't be chosen.
-    results = store.mmr_traversal_search("0.0", k=2, score_threshold=0.2)
+    results = g_store.mmr_traversal_search("0.0", k=2, score_threshold=0.2)
     assert _result_ids(results) == ["v0"]
 
     # with k=4 we should get all of the documents.
-    results = store.mmr_traversal_search("0.0", k=4)
+    results = g_store.mmr_traversal_search("0.0", k=4)
     assert _result_ids(results) == ["v0", "v2", "v1", "v3"]
 
 
-def test_write_retrieve_keywords() -> None:
-    from langchain_openai import OpenAIEmbeddings
-
-    greetings = Document(
+def test_write_retrieve_keywords(
+    graph_vector_store_earth: CassandraGraphVectorStore,
+) -> None:
+    greetings = Node(
         id="greetings",
-        page_content="Typical Greetings",
-        metadata={
-            METADATA_LINKS_KEY: [
-                Link.incoming(kind="parent", tag="parent"),
-            ],
-        },
+        text="Typical Greetings",
+        links=[
+            Link.incoming(kind="parent", tag="parent"),
+        ],
     )
-    doc1 = Document(
+
+    node1 = Node(
         id="doc1",
-        page_content="Hello World",
-        metadata={
-            METADATA_LINKS_KEY: [
-                Link.outgoing(kind="parent", tag="parent"),
-                Link.bidir(kind="kw", tag="greeting"),
-                Link.bidir(kind="kw", tag="world"),
-            ],
-        },
+        text="Hello World",
+        links=[
+            Link.outgoing(kind="parent", tag="parent"),
+            Link.bidir(kind="kw", tag="greeting"),
+            Link.bidir(kind="kw", tag="world"),
+        ],
     )
-    doc2 = Document(
+
+    node2 = Node(
         id="doc2",
-        page_content="Hello Earth",
-        metadata={
-            METADATA_LINKS_KEY: [
-                Link.outgoing(kind="parent", tag="parent"),
-                Link.bidir(kind="kw", tag="greeting"),
-                Link.bidir(kind="kw", tag="earth"),
-            ],
-        },
+        text="Hello Earth",
+        links=[
+            Link.outgoing(kind="parent", tag="parent"),
+            Link.bidir(kind="kw", tag="greeting"),
+            Link.bidir(kind="kw", tag="earth"),
+        ],
     )
-    store = _get_graph_store(OpenAIEmbeddings, [greetings, doc1, doc2])
+
+    g_store = graph_vector_store_earth
+    g_store.add_nodes(nodes=[greetings, node1, node2])
 
     # Doc2 is more similar, but World and Earth are similar enough that doc1 also
     # shows up.
-    results: Iterable[Document] = store.similarity_search("Earth", k=2)
+    results: Iterable[Document] = g_store.similarity_search("Earth", k=2)
     assert _result_ids(results) == ["doc2", "doc1"]
 
-    results = store.similarity_search("Earth", k=1)
+    results = g_store.similarity_search("Earth", k=1)
     assert _result_ids(results) == ["doc2"]
 
-    results = store.traversal_search("Earth", k=2, depth=0)
+    results = g_store.traversal_search("Earth", k=2, depth=0)
    assert _result_ids(results) == ["doc2", "doc1"]
 
-    results = store.traversal_search("Earth", k=2, depth=1)
+    results = g_store.traversal_search("Earth", k=2, depth=1)
     assert _result_ids(results) == ["doc2", "doc1", "greetings"]
 
     # K=1 only pulls in doc2 (Hello Earth)
-    results = store.traversal_search("Earth", k=1, depth=0)
+    results = g_store.traversal_search("Earth", k=1, depth=0)
     assert _result_ids(results) == ["doc2"]
 
     # K=1 only pulls in doc2 (Hello Earth). Depth=1 traverses to parent and via
     # keyword edge.
-    results = store.traversal_search("Earth", k=1, depth=1)
+    results = g_store.traversal_search("Earth", k=1, depth=1)
     assert set(_result_ids(results)) == {"doc2", "doc1", "greetings"}
 
 
-def test_metadata() -> None:
-    store = _get_graph_store(FakeEmbeddings)
-    store.add_documents(
-        [
-            Document(
-                id="a",
-                page_content="A",
-                metadata={
-                    METADATA_LINKS_KEY: [
-                        Link.incoming(kind="hyperlink", tag="http://a"),
-                        Link.bidir(kind="other", tag="foo"),
-                    ],
-                    "other": "some other field",
-                },
-            )
-        ]
+def test_metadata(graph_vector_store_fake: CassandraGraphVectorStore) -> None:
+    doc_a = Node(
+        id="a",
+        text="A",
+        metadata={"other": "some other field"},
+        links=[
+            Link.incoming(kind="hyperlink", tag="http://a"),
+            Link.bidir(kind="other", tag="foo"),
+        ],
     )
-    results = store.similarity_search("A")
+
+    g_store = graph_vector_store_fake
+    g_store.add_nodes([doc_a])
+    results = g_store.similarity_search("A")
     assert len(results) == 1
     assert results[0].id == "a"
     metadata = results[0].metadata
@@ -270,3 +397,274 @@ def test_metadata() -> None:
         Link.incoming(kind="hyperlink", tag="http://a"),
         Link.bidir(kind="other", tag="foo"),
     }
+
+
+class TestCassandraGraphVectorStore:
+    def test_gvs_similarity_search_sync(
+        self,
+        populated_graph_vector_store_d2: CassandraGraphVectorStore,
+    ) -> None:
+        """Simple (non-graph) similarity search on a graph vector store."""
+        g_store = populated_graph_vector_store_d2
+        ss_response = g_store.similarity_search(query="[2, 10]", k=2)
+        ss_labels = [doc.metadata["label"] for doc in ss_response]
+        assert ss_labels == ["AR", "A0"]
+        ss_by_v_response = g_store.similarity_search_by_vector(embedding=[2, 10], k=2)
+        ss_by_v_labels = [doc.metadata["label"] for doc in ss_by_v_response]
+        assert ss_by_v_labels == ["AR", "A0"]
+
+    async def test_gvs_similarity_search_async(
+        self,
+        populated_graph_vector_store_d2: CassandraGraphVectorStore,
+    ) -> None:
+        """Simple (non-graph) similarity search on a graph vector store."""
+        g_store = populated_graph_vector_store_d2
+        ss_response = await g_store.asimilarity_search(query="[2, 10]", k=2)
+        ss_labels = [doc.metadata["label"] for doc in ss_response]
+        assert ss_labels == ["AR", "A0"]
+        ss_by_v_response = await g_store.asimilarity_search_by_vector(
+            embedding=[2, 10], k=2
+        )
+        ss_by_v_labels = [doc.metadata["label"] for doc in ss_by_v_response]
+        assert ss_by_v_labels == ["AR", "A0"]
+
+    def test_gvs_traversal_search_sync(
+        self,
+        populated_graph_vector_store_d2: CassandraGraphVectorStore,
+    ) -> None:
+        """Graph traversal search on a graph vector store."""
+        g_store = populated_graph_vector_store_d2
+        ts_response = g_store.traversal_search(query="[2, 10]", k=2, depth=2)
+        # this is a set, as some of the internals of trav.search are set-driven
+        # so ordering is not deterministic:
+        ts_labels = {doc.metadata["label"] for doc in ts_response}
+        assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"}
+
+    async def test_gvs_traversal_search_async(
+        self,
+        populated_graph_vector_store_d2: CassandraGraphVectorStore,
+    ) -> None:
+        """Graph traversal search on a graph vector store."""
+        g_store = populated_graph_vector_store_d2
+        ts_labels = set()
+        async for doc in g_store.atraversal_search(query="[2, 10]", k=2, depth=2):
+            ts_labels.add(doc.metadata["label"])
+        # this is a set, as some of the internals of trav.search are set-driven
+        # so ordering is not deterministic:
+        assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"}
+
+    def test_gvs_mmr_traversal_search_sync(
+        self,
+        populated_graph_vector_store_d2: CassandraGraphVectorStore,
+    ) -> None:
+        """MMR Graph traversal search on a graph vector store."""
+        g_store = populated_graph_vector_store_d2
+        mt_response = g_store.mmr_traversal_search(
+            query="[2, 10]",
+            k=2,
+            depth=2,
+            fetch_k=1,
+            adjacent_k=2,
+            lambda_mult=0.1,
+        )
+        # TODO: can this rightfully be a list (or must it be a set)?
+        mt_labels = {doc.metadata["label"] for doc in mt_response}
+        assert mt_labels == {"AR", "BR"}
+
+    async def test_gvs_mmr_traversal_search_async(
+        self,
+        populated_graph_vector_store_d2: CassandraGraphVectorStore,
+    ) -> None:
+        """MMR Graph traversal search on a graph vector store."""
+        g_store = populated_graph_vector_store_d2
+        mt_labels = set()
+        async for doc in g_store.ammr_traversal_search(
+            query="[2, 10]",
+            k=2,
+            depth=2,
+            fetch_k=1,
+            adjacent_k=2,
+            lambda_mult=0.1,
+        ):
+            mt_labels.add(doc.metadata["label"])
+        # TODO: can this rightfully be a list (or must it be a set)?
+        assert mt_labels == {"AR", "BR"}
+
+    def test_gvs_metadata_search_sync(
+        self,
+        populated_graph_vector_store_d2: CassandraGraphVectorStore,
+    ) -> None:
+        """Metadata search on a graph vector store."""
+        g_store = populated_graph_vector_store_d2
+        mt_response = g_store.metadata_search(
+            filter={"label": "T0"},
+            n=2,
+        )
+        doc: Document = next(iter(mt_response))
+        assert doc.page_content == "[-10, 0]"
+        links = doc.metadata["links"]
+        assert len(links) == 1
+        link: Link = links.pop()
+        assert isinstance(link, Link)
+        assert link.direction == "in"
+        assert link.kind == "at_example"
+        assert link.tag == "tag_0"
+
+    async def test_gvs_metadata_search_async(
+        self,
+        populated_graph_vector_store_d2: CassandraGraphVectorStore,
+    ) -> None:
+        """Metadata search on a graph vector store."""
+        g_store = populated_graph_vector_store_d2
+        mt_response = await g_store.ametadata_search(
+            filter={"label": "T0"},
+            n=2,
+        )
+        doc: Document = next(iter(mt_response))
+        assert doc.page_content == "[-10, 0]"
+        links: set[Link] = doc.metadata["links"]
+        assert len(links) == 1
+        link: Link = links.pop()
+        assert isinstance(link, Link)
+        assert link.direction == "in"
+        assert link.kind == "at_example"
+        assert link.tag == "tag_0"
+
+    def test_gvs_get_by_document_id_sync(
+        self,
+        populated_graph_vector_store_d2: CassandraGraphVectorStore,
+    ) -> None:
+        """Get by document_id on a graph vector store."""
+        g_store = populated_graph_vector_store_d2
+        doc = g_store.get_by_document_id(document_id="FL")
+        assert doc is not None
+        assert doc.page_content == "[1, -9]"
+        links = doc.metadata["links"]
+        assert len(links) == 1
+        link: Link = links.pop()
+        assert isinstance(link, Link)
+        assert link.direction == "out"
+        assert link.kind == "af_example"
+        assert link.tag == "tag_l"
+
+        invalid_doc = g_store.get_by_document_id(document_id="invalid")
+        assert invalid_doc is None
+
+    async def test_gvs_get_by_document_id_async(
+        self,
+        populated_graph_vector_store_d2: CassandraGraphVectorStore,
+    ) -> None:
+        """Get by document_id on a graph vector store."""
+        g_store = populated_graph_vector_store_d2
+        doc = await g_store.aget_by_document_id(document_id="FL")
+        assert doc is not None
+        assert doc.page_content == "[1, -9]"
+        links = doc.metadata["links"]
+        assert len(links) == 1
+        link: Link = links.pop()
+        assert isinstance(link, Link)
+        assert link.direction == "out"
+        assert link.kind == "af_example"
+        assert link.tag == "tag_l"
+
+        invalid_doc = await g_store.aget_by_document_id(document_id="invalid")
+        assert invalid_doc is None
+
+    def test_gvs_from_texts(
+        self,
+        graph_vector_store_d2: CassandraGraphVectorStore,
+    ) -> None:
+        g_store = graph_vector_store_d2
+        g_store.add_texts(
+            texts=["[1, 2]"],
+            metadatas=[{"md": 1}],
+            ids=["x_id"],
+        )
+
+        hits = g_store.similarity_search("[2, 1]", k=2)
+        assert len(hits) == 1
+        assert hits[0].page_content == "[1, 2]"
+        assert hits[0].id == "x_id"
+        # there may be more re:graph structure.
+        assert hits[0].metadata["md"] == "1.0"
+
+    def test_gvs_from_documents_containing_ids(
+        self,
+        graph_vector_store_d2: CassandraGraphVectorStore,
+    ) -> None:
+        the_document = Document(
+            page_content="[1, 2]",
+            metadata={"md": 1},
+            id="x_id",
+        )
+        g_store = graph_vector_store_d2
+        g_store.add_documents([the_document])
+        hits = g_store.similarity_search("[2, 1]", k=2)
+        assert len(hits) == 1
+        assert hits[0].page_content == "[1, 2]"
+        assert hits[0].id == "x_id"
+        # there may be more re:graph structure.
+        assert hits[0].metadata["md"] == "1.0"
+
+    def test_gvs_add_nodes_sync(
+        self,
+        *,
+        graph_vector_store_d2: CassandraGraphVectorStore,
+    ) -> None:
+        links0 = [
+            Link(kind="kA", direction="out", tag="tA"),
+            Link(kind="kB", direction="bidir", tag="tB"),
+        ]
+        links1 = [
+            Link(kind="kC", direction="in", tag="tC"),
+        ]
+        nodes = [
+            Node(id="id0", text="[1, 0]", metadata={"m": 0}, links=links0),
+            Node(text="[-1, 0]", metadata={"m": 1}, links=links1),
+        ]
+        graph_vector_store_d2.add_nodes(nodes)
+        hits = graph_vector_store_d2.similarity_search_by_vector([0.9, 0.1])
+        assert len(hits) == 2
+        assert hits[0].id == "id0"
+        assert hits[0].page_content == "[1, 0]"
+        md0 = hits[0].metadata
+        assert md0["m"] == "0.0"
+        assert any(isinstance(v, set) for k, v in md0.items() if k != "m")
+
+        assert hits[1].id != "id0"
+        assert hits[1].page_content == "[-1, 0]"
+        md1 = hits[1].metadata
+        assert md1["m"] == "1.0"
+        assert any(isinstance(v, set) for k, v in md1.items() if k != "m")
+
+    async def test_gvs_add_nodes_async(
+        self,
+        *,
+        graph_vector_store_d2: CassandraGraphVectorStore,
+    ) -> None:
+        links0 = [
+            Link(kind="kA", direction="out", tag="tA"),
+            Link(kind="kB", direction="bidir", tag="tB"),
+        ]
+        links1 = [
+            Link(kind="kC", direction="in", tag="tC"),
+        ]
+        nodes = [
+            Node(id="id0", text="[1, 0]", metadata={"m": 0}, links=links0),
+            Node(text="[-1, 0]", metadata={"m": 1}, links=links1),
+        ]
+        async for _ in graph_vector_store_d2.aadd_nodes(nodes):
+            pass
+
+        hits = await graph_vector_store_d2.asimilarity_search_by_vector([0.9, 0.1])
+        assert len(hits) == 2
+        assert hits[0].id == "id0"
+        assert hits[0].page_content == "[1, 0]"
+        md0 = hits[0].metadata
+        assert md0["m"] == "0.0"
+        assert any(isinstance(v, set) for k, v in md0.items() if k != "m")
+        assert hits[1].id != "id0"
+        assert hits[1].page_content == "[-1, 0]"
+        md1 = hits[1].metadata
+        assert md1["m"] == "1.0"
+        assert any(isinstance(v, set) for k, v in md1.items() if k != "m")
@@ -0,0 +1,269 @@
+"""Test of Upgrading to Apache Cassandra graph vector store class:
+`CassandraGraphVectorStore` from an existing table used
+by the Cassandra vector store class: `Cassandra`
+"""
+
+from __future__ import annotations
+
+import json
+import os
+from contextlib import contextmanager
+from typing import Any, Generator, Iterable, Optional, Tuple, Union
+
+import pytest
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+
+from langchain_community.graph_vectorstores import CassandraGraphVectorStore
+from langchain_community.utilities.cassandra import SetupMode
+from langchain_community.vectorstores import Cassandra
+
+TEST_KEYSPACE = "graph_test_keyspace"
+
+TABLE_NAME_ALLOW_INDEXING = "allow_graph_table"
+TABLE_NAME_DEFAULT = "default_graph_table"
+TABLE_NAME_DENY_INDEXING = "deny_graph_table"
+
+
+class ParserEmbeddings(Embeddings):
+    """Parse input texts: if they are json for a List[float], fine.
+    Otherwise, return all zeros and call it a day.
+    """
+
+    def __init__(self, dimension: int) -> None:
+        self.dimension = dimension
+
+    def embed_documents(self, texts: list[str]) -> list[list[float]]:
+        return [self.embed_query(txt) for txt in texts]
+
+    async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
+        return self.embed_documents(texts)
+
+    def embed_query(self, text: str) -> list[float]:
+        try:
+            vals = json.loads(text)
+        except json.JSONDecodeError:
+            return [0.0] * self.dimension
+        else:
+            assert len(vals) == self.dimension
+            return vals
+
+    async def aembed_query(self, text: str) -> list[float]:
+        return self.embed_query(text)
+
+
+@pytest.fixture
+def embedding_d2() -> Embeddings:
+    return ParserEmbeddings(dimension=2)
+
+
+class CassandraSession:
+    table_name: str
+    session: Any
+
+    def __init__(self, table_name: str, session: Any):
+        self.table_name = table_name
+        self.session = session
+
+
+@contextmanager
+def get_cassandra_session(
+    table_name: str, drop: bool = True
+) -> Generator[CassandraSession, None, None]:
+    """Initialize the Cassandra cluster and session"""
+    from cassandra.cluster import Cluster
+
+    if "CASSANDRA_CONTACT_POINTS" in os.environ:
+        contact_points = [
+            cp.strip()
+            for cp in os.environ["CASSANDRA_CONTACT_POINTS"].split(",")
+            if cp.strip()
+        ]
+    else:
+        contact_points = None
+
+    cluster = Cluster(contact_points)
+    session = cluster.connect()
+
+    try:
+        session.execute(
+            (
+                f"CREATE KEYSPACE IF NOT EXISTS {TEST_KEYSPACE}"
+                " WITH replication = "
+                "{'class': 'SimpleStrategy', 'replication_factor': 1}"
+            )
+        )
+        if drop:
+            session.execute(f"DROP TABLE IF EXISTS {TEST_KEYSPACE}.{table_name}")
+
+        # Yield the session for usage
+        yield CassandraSession(table_name=table_name, session=session)
+    finally:
+        # Ensure proper shutdown/cleanup of resources
+        session.shutdown()
+        cluster.shutdown()
+
+
+@contextmanager
+def vector_store(
+    embedding: Embeddings,
+    table_name: str,
+    setup_mode: SetupMode,
+    metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
+    drop: bool = True,
+) -> Generator[Cassandra, None, None]:
+    with get_cassandra_session(table_name=table_name, drop=drop) as session:
+        yield Cassandra(
+            table_name=session.table_name,
+            keyspace=TEST_KEYSPACE,
+            session=session.session,
+            embedding=embedding,
+            setup_mode=setup_mode,
+            metadata_indexing=metadata_indexing,
+        )
+
+
+@contextmanager
+def graph_vector_store(
+    embedding: Embeddings,
+    table_name: str,
+    setup_mode: SetupMode,
+    metadata_deny_list: Optional[list[str]] = None,
+    drop: bool = True,
+) -> Generator[CassandraGraphVectorStore, None, None]:
+    with get_cassandra_session(table_name=table_name, drop=drop) as session:
+        yield CassandraGraphVectorStore(
+            table_name=session.table_name,
+            keyspace=TEST_KEYSPACE,
+            session=session.session,
+            embedding=embedding,
+            setup_mode=setup_mode,
+            metadata_deny_list=metadata_deny_list,
+        )
+
+
+def _vs_indexing_policy(table_name: str) -> Union[Tuple[str, Iterable[str]], str]:
+    if table_name == TABLE_NAME_ALLOW_INDEXING:
+        return ("allowlist", ["test"])
+    if table_name == TABLE_NAME_DEFAULT:
+        return "all"
+    if table_name == TABLE_NAME_DENY_INDEXING:
+        return ("denylist", ["test"])
+    msg = f"Unknown table_name: {table_name} in _vs_indexing_policy()"
+    raise ValueError(msg)
+
+
+class TestUpgradeToGraphVectorStore:
+    @pytest.mark.parametrize(
+        ("table_name", "gvs_setup_mode", "gvs_metadata_deny_list"),
+        [
+            (TABLE_NAME_DEFAULT, SetupMode.SYNC, None),
+            (TABLE_NAME_DENY_INDEXING, SetupMode.SYNC, ["test"]),
+            (TABLE_NAME_DEFAULT, SetupMode.OFF, None),
+            (TABLE_NAME_DENY_INDEXING, SetupMode.OFF, ["test"]),
+            # for this one, even though the passed policy doesn't
+            # match the policy used to create the collection,
+            # there is no error since the SetupMode is OFF and
+            # no attempt is made to re-create the collection.
+            (TABLE_NAME_DENY_INDEXING, SetupMode.OFF, None),
+        ],
+        ids=[
+            "default_upgrade_no_policy_sync",
+            "deny_list_upgrade_same_policy_sync",
+            "default_upgrade_no_policy_off",
+            "deny_list_upgrade_same_policy_off",
+            "deny_list_upgrade_change_policy_off",
+        ],
+    )
+    def test_upgrade_to_gvs_success_sync(
+        self,
+        *,
+        embedding_d2: Embeddings,
+        gvs_setup_mode: SetupMode,
+        table_name: str,
+        gvs_metadata_deny_list: list[str],
+    ) -> None:
+        doc_id = "AL"
+        doc_al = Document(id=doc_id, page_content="[-1, 9]", metadata={"label": "AL"})
+
+        # Create vector store using SetupMode.SYNC
+        with vector_store(
+            embedding=embedding_d2,
+            table_name=table_name,
+            setup_mode=SetupMode.SYNC,
+            metadata_indexing=_vs_indexing_policy(table_name=table_name),
+            drop=True,
+        ) as v_store:
+            # load a document to the vector store
+            v_store.add_documents([doc_al])
+
+            # get the document from the vector store
+            v_doc = v_store.get_by_document_id(document_id=doc_id)
+            assert v_doc is not None
+            assert v_doc.page_content == doc_al.page_content
+
+        # Create a GRAPH Vector Store using the existing collection from above
+        # with setup_mode=gvs_setup_mode and indexing_policy=gvs_indexing_policy
+        with graph_vector_store(
+            embedding=embedding_d2,
+            table_name=table_name,
+            setup_mode=gvs_setup_mode,
+            metadata_deny_list=gvs_metadata_deny_list,
+            drop=False,
+        ) as gv_store:
+            # get the document from the GRAPH vector store
+            gv_doc = gv_store.get_by_document_id(document_id=doc_id)
+            assert gv_doc is not None
+            assert gv_doc.page_content == doc_al.page_content
+
+    @pytest.mark.parametrize(
+        ("table_name", "gvs_setup_mode", "gvs_metadata_deny_list"),
+        [
+            (TABLE_NAME_DEFAULT, SetupMode.ASYNC, None),
+            (TABLE_NAME_DENY_INDEXING, SetupMode.ASYNC, ["test"]),
+        ],
+        ids=[
+            "default_upgrade_no_policy_async",
+            "deny_list_upgrade_same_policy_async",
+        ],
+    )
+    async def test_upgrade_to_gvs_success_async(
+        self,
+        *,
+        embedding_d2: Embeddings,
+        gvs_setup_mode: SetupMode,
+        table_name: str,
+        gvs_metadata_deny_list: list[str],
+    ) -> None:
+        doc_id = "AL"
+        doc_al = Document(id=doc_id, page_content="[-1, 9]", metadata={"label": "AL"})
+
+        # Create vector store using SetupMode.ASYNC
+        with vector_store(
+            embedding=embedding_d2,
+            table_name=table_name,
+            setup_mode=SetupMode.ASYNC,
+            metadata_indexing=_vs_indexing_policy(table_name=table_name),
+            drop=True,
+        ) as v_store:
+            # load a document to the vector store
+            await v_store.aadd_documents([doc_al])
+
+            # get the document from the vector store
+            v_doc = await v_store.aget_by_document_id(document_id=doc_id)
+            assert v_doc is not None
+            assert v_doc.page_content == doc_al.page_content
+
+        # Create a GRAPH Vector Store using the existing collection from above
+        # with setup_mode=gvs_setup_mode and indexing_policy=gvs_indexing_policy
+        with graph_vector_store(
+            embedding=embedding_d2,
+            table_name=table_name,
+            setup_mode=gvs_setup_mode,
+            metadata_deny_list=gvs_metadata_deny_list,
+            drop=False,
+        ) as gv_store:
+            # get the document from the GRAPH vector store
+            gv_doc = await gv_store.aget_by_document_id(document_id=doc_id)
+            assert gv_doc is not None
+            assert gv_doc.page_content == doc_al.page_content
File diff suppressed because it is too large