community: Cassandra Vector Store: modernize implementation (#27253)

**Description:** 

This PR updates `CassandraGraphVectorStore` to be based on
`CassandraVectorStore` instead of using a custom CQL implementation.
This allows users of a `CassandraVectorStore` to upgrade to a
`GraphVectorStore` without having to change their database schema or
re-embed documents.

This PR also updates the documentation of the `GraphVectorStore` base
class and adds native async implementations of the standard graph
methods `traversal_search` and `mmr_traversal_search` to
`CassandraGraphVectorStore`.
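
For illustration, a minimal sketch of the new native async methods (assuming `store` is an already-initialized `CassandraGraphVectorStore`):

```python
# Minimal usage sketch; `store` is assumed to be an
# already-initialized CassandraGraphVectorStore.
import asyncio

async def traverse(store):
    # Plain traversal: similarity seeds, then edge traversal up to `depth`.
    async for doc in store.atraversal_search("my query", k=4, depth=2):
        print(doc.id)
    # MMR traversal: balances similarity and diversity while traversing.
    async for doc in store.ammr_traversal_search(
        "my query", k=4, fetch_k=100, adjacent_k=10, lambda_mult=0.5, depth=2
    ):
        print(doc.id)

# asyncio.run(traverse(store))
```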

**Issue:** No issue number.

**Dependencies:** https://github.com/langchain-ai/langchain/pull/27078
(already merged)

**Lint and test:**
- Lint and tests all pass, including existing
`CassandraGraphVectorStore` tests.
- Also added numerous additional tests based on the tests in
`langchain-astradb`, which cover many more scenarios than the existing
tests for `Cassandra` and `CassandraGraphVectorStore`.

**BREAKING CHANGE**

Note that this is a breaking change for existing users of
`CassandraGraphVectorStore`. They will need to wipe their database table
and re-add their documents.

However:
- The interfaces have not changed, just the underlying storage
mechanism.
- Anyone using `langchain_community.vectorstores.Cassandra` can instead
use `langchain_community.graph_vectorstores.CassandraGraphVectorStore`
and gain graph capabilities without having to re-embed their
existing documents. This is the primary goal of this PR (see the sketch
below).
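
A minimal sketch of that upgrade path (the session, keyspace, table name, and embeddings below are placeholders for an existing deployment):

```python
# Hedged sketch of the upgrade path; all connection values are placeholders.
from langchain_community.graph_vectorstores import CassandraGraphVectorStore

g_store = CassandraGraphVectorStore(
    embedding=my_embeddings,  # the same Embeddings used with `Cassandra`
    session=session,          # your existing Cassandra session
    keyspace="my_keyspace",
    table_name="my_table",    # the table previously used by `Cassandra`
)

# Existing rows are readable immediately; no re-embedding required.
for doc in g_store.traversal_search("my query", k=4, depth=1):
    print(doc.id)
```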

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
Authored by Eric Pinzur on 2024-10-22 20:11:11 +02:00; committed by GitHub as commit f636c83321 (parent 0640cbf2f1).
9 changed files with 4070 additions and 679 deletions


@@ -144,6 +144,7 @@ from langchain_community.graph_vectorstores.cassandra import CassandraGraphVecto
from langchain_community.graph_vectorstores.links import (
Link,
)
from langchain_community.graph_vectorstores.mmr_helper import MmrHelper
__all__ = [
"GraphVectorStore",
@@ -151,4 +152,5 @@ __all__ = [
"Node",
"Link",
"CassandraGraphVectorStore",
"MmrHelper",
]


@@ -1,11 +1,13 @@
from __future__ import annotations
import logging
from abc import abstractmethod
from collections.abc import AsyncIterable, Collection, Iterable, Iterator
from typing import (
Any,
ClassVar,
Optional,
Sequence,
)
from langchain_core._api import beta
@@ -21,6 +23,8 @@ from pydantic import Field
from langchain_community.graph_vectorstores.links import METADATA_LINKS_KEY, Link
logger = logging.getLogger(__name__)
def _has_next(iterator: Iterator) -> bool:
"""Checks if the iterator has more elements.
@@ -158,6 +162,7 @@ class GraphVectorStore(VectorStore):
Args:
nodes: the nodes to add.
**kwargs: Additional keyword arguments.
"""
async def aadd_nodes(
@@ -169,6 +174,7 @@ class GraphVectorStore(VectorStore):
Args:
nodes: the nodes to add.
**kwargs: Additional keyword arguments.
"""
iterator = iter(await run_in_executor(None, self.add_nodes, nodes, **kwargs))
done = object()
@@ -383,6 +389,7 @@ class GraphVectorStore(VectorStore):
*,
k: int = 4,
depth: int = 1,
filter: dict[str, Any] | None = None, # noqa: A002
**kwargs: Any,
) -> Iterable[Document]:
"""Retrieve documents from traversing this graph store.
@@ -396,8 +403,10 @@ class GraphVectorStore(VectorStore):
k: The number of Documents to return from the initial search.
Defaults to 4. Applies to each of the query strings.
depth: The maximum depth of edges to traverse. Defaults to 1.
filter: Optional metadata to filter the results.
**kwargs: Additional keyword arguments.
Returns:
Retrieved documents.
Collection of retrieved documents.
"""
async def atraversal_search(
@@ -406,6 +415,7 @@ class GraphVectorStore(VectorStore):
*,
k: int = 4,
depth: int = 1,
filter: dict[str, Any] | None = None, # noqa: A002
**kwargs: Any,
) -> AsyncIterable[Document]:
"""Retrieve documents from traversing this graph store.
@@ -419,12 +429,20 @@ class GraphVectorStore(VectorStore):
k: The number of Documents to return from the initial search.
Defaults to 4. Applies to each of the query strings.
depth: The maximum depth of edges to traverse. Defaults to 1.
filter: Optional metadata to filter the results.
**kwargs: Additional keyword arguments.
Returns:
Retrieved documents.
Collection of retrieved documents.
"""
iterator = iter(
await run_in_executor(
None, self.traversal_search, query, k=k, depth=depth, **kwargs
None,
self.traversal_search,
query,
k=k,
depth=depth,
filter=filter,
**kwargs,
)
)
done = object()
@@ -439,12 +457,14 @@ class GraphVectorStore(VectorStore):
self,
query: str,
*,
initial_roots: Sequence[str] = (),
k: int = 4,
depth: int = 2,
fetch_k: int = 100,
adjacent_k: int = 10,
lambda_mult: float = 0.5,
score_threshold: float = float("-inf"),
filter: dict[str, Any] | None = None, # noqa: A002
**kwargs: Any,
) -> Iterable[Document]:
"""Retrieve documents from this graph store using MMR-traversal.
@@ -459,6 +479,10 @@ class GraphVectorStore(VectorStore):
Args:
query: The query string to search for.
initial_roots: Optional list of document IDs to use for initializing search.
The top `adjacent_k` nodes adjacent to each initial root will be
included in the set of initial candidates. To fetch only in the
neighborhood of these nodes, set `fetch_k = 0`.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch via similarity.
Defaults to 100.
@@ -471,18 +495,22 @@ class GraphVectorStore(VectorStore):
diversity and 1 to minimum diversity. Defaults to 0.5.
score_threshold: Only documents with a score greater than or equal
this threshold will be chosen. Defaults to negative infinity.
filter: Optional metadata to filter the results.
**kwargs: Additional keyword arguments.
"""
async def ammr_traversal_search(
self,
query: str,
*,
initial_roots: Sequence[str] = (),
k: int = 4,
depth: int = 2,
fetch_k: int = 100,
adjacent_k: int = 10,
lambda_mult: float = 0.5,
score_threshold: float = float("-inf"),
filter: dict[str, Any] | None = None, # noqa: A002
**kwargs: Any,
) -> AsyncIterable[Document]:
"""Retrieve documents from this graph store using MMR-traversal.
@@ -497,6 +525,10 @@ class GraphVectorStore(VectorStore):
Args:
query: The query string to search for.
initial_roots: Optional list of document IDs to use for initializing search.
The top `adjacent_k` nodes adjacent to each initial root will be
included in the set of initial candidates. To fetch only in the
neighborhood of these nodes, set `fetch_k = 0`.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch via similarity.
Defaults to 100.
@@ -509,18 +541,22 @@ class GraphVectorStore(VectorStore):
diversity and 1 to minimum diversity. Defaults to 0.5.
score_threshold: Only documents with a score greater than or equal
this threshold will be chosen. Defaults to negative infinity.
filter: Optional metadata to filter the results.
**kwargs: Additional keyword arguments.
"""
iterator = iter(
await run_in_executor(
None,
self.mmr_traversal_search,
query,
initial_roots=initial_roots,
k=k,
fetch_k=fetch_k,
adjacent_k=adjacent_k,
depth=depth,
lambda_mult=lambda_mult,
score_threshold=score_threshold,
filter=filter,
**kwargs,
)
)
@@ -544,6 +580,11 @@ class GraphVectorStore(VectorStore):
lambda_mult: float = 0.5,
**kwargs: Any,
) -> list[Document]:
if kwargs.get("depth", 0) > 0:
logger.warning(
"'mmr' search started with depth > 0. "
"Maybe you meant to do a 'mmr_traversal' search?"
)
return list(
self.mmr_traversal_search(
query, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, depth=0
@@ -573,7 +614,7 @@ class GraphVectorStore(VectorStore):
raise ValueError(
f"search_type of {search_type} not allowed. Expected "
"search_type to be 'similarity', 'similarity_score_threshold', "
"'mmr' or 'traversal'."
"'mmr', 'traversal', or 'mmr_traversal'."
)
async def asearch(
@@ -590,11 +631,13 @@ class GraphVectorStore(VectorStore):
return await self.amax_marginal_relevance_search(query, **kwargs)
elif search_type == "traversal":
return [doc async for doc in self.atraversal_search(query, **kwargs)]
elif search_type == "mmr_traversal":
return [doc async for doc in self.ammr_traversal_search(query, **kwargs)]
else:
raise ValueError(
f"search_type of {search_type} not allowed. Expected "
"search_type to be 'similarity', 'similarity_score_threshold', "
"'mmr' or 'traversal'."
"'mmr', 'traversal', or 'mmr_traversal'."
)
def as_retriever(self, **kwargs: Any) -> GraphVectorStoreRetriever:
@@ -606,13 +649,14 @@ class GraphVectorStore(VectorStore):
- search_type (Optional[str]): Defines the type of search that
the Retriever should perform.
Can be ``traversal`` (default), ``similarity``, ``mmr``, or
``similarity_score_threshold``.
Can be ``traversal`` (default), ``similarity``, ``mmr``,
``mmr_traversal``, or ``similarity_score_threshold``.
- search_kwargs (Optional[Dict]): Keyword arguments to pass to the
search function. Can include things like:
- k(int): Amount of documents to return (Default: 4).
- depth(int): The maximum depth of edges to traverse (Default: 1).
Only applies to search_type: ``traversal`` and ``mmr_traversal``.
- score_threshold(float): Minimum relevance threshold
for similarity_score_threshold.
- fetch_k(int): Amount of documents to pass to MMR algorithm
@@ -629,21 +673,21 @@ class GraphVectorStore(VectorStore):
# Retrieve documents traversing edges
docsearch.as_retriever(
search_type="traversal",
search_kwargs={'k': 6, 'depth': 3}
search_kwargs={'k': 6, 'depth': 2}
)
# Retrieve more documents with higher diversity
# Retrieve documents with higher diversity
# Useful if your dataset has many similar documents
docsearch.as_retriever(
search_type="mmr",
search_kwargs={'k': 6, 'lambda_mult': 0.25}
search_type="mmr_traversal",
search_kwargs={'k': 6, 'lambda_mult': 0.25, 'depth': 2}
)
# Fetch more documents for the MMR algorithm to consider
# But only return the top 5
docsearch.as_retriever(
search_type="mmr",
search_kwargs={'k': 5, 'fetch_k': 50}
search_type="mmr_traversal",
search_kwargs={'k': 5, 'fetch_k': 50, 'depth': 2}
)
# Only retrieve documents that have a relevance score
@@ -657,7 +701,7 @@ class GraphVectorStore(VectorStore):
docsearch.as_retriever(search_kwargs={'k': 1})
"""
return GraphVectorStoreRetriever(vectorstore=self, **kwargs)
return GraphVectorStoreRetriever(vector_store=self, **kwargs)
@beta(message="Added in version 0.3.1 of langchain_community. API subject to change.")
@@ -793,7 +837,7 @@ class GraphVectorStoreRetriever(VectorStoreRetriever):
retriever = graph_vectorstore.as_retriever(search_kwargs={"score_threshold": 0.5})
""" # noqa: E501
vectorstore: GraphVectorStore
vector_store: GraphVectorStore
"""GraphVectorStore to use for retrieval."""
search_type: str = "traversal"
"""Type of search to perform. Defaults to "traversal"."""
@@ -809,10 +853,10 @@ class GraphVectorStoreRetriever(VectorStoreRetriever):
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> list[Document]:
if self.search_type == "traversal":
return list(self.vectorstore.traversal_search(query, **self.search_kwargs))
return list(self.vector_store.traversal_search(query, **self.search_kwargs))
elif self.search_type == "mmr_traversal":
return list(
self.vectorstore.mmr_traversal_search(query, **self.search_kwargs)
self.vector_store.mmr_traversal_search(query, **self.search_kwargs)
)
else:
return super()._get_relevant_documents(query, run_manager=run_manager)
@@ -823,14 +867,14 @@ class GraphVectorStoreRetriever(VectorStoreRetriever):
if self.search_type == "traversal":
return [
doc
async for doc in self.vectorstore.atraversal_search(
async for doc in self.vector_store.atraversal_search(
query, **self.search_kwargs
)
]
elif self.search_type == "mmr_traversal":
return [
doc
async for doc in self.vectorstore.ammr_traversal_search(
async for doc in self.vector_store.ammr_traversal_search(
query, **self.search_kwargs
)
]
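
For reference, a short hedged sketch of retriever usage after these renames (`g_store` stands in for any initialized `GraphVectorStore`):

```python
# Sketch of GraphVectorStoreRetriever usage after the changes above.
# `g_store` is assumed to be an already-initialized GraphVectorStore.
retriever = g_store.as_retriever(
    search_type="mmr_traversal",
    search_kwargs={"k": 4, "depth": 2, "lambda_mult": 0.25},
)
docs = retriever.invoke("my query")

# The retriever now references the store as `vector_store` (was `vectorstore`).
assert retriever.vector_store is g_store
```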


@@ -0,0 +1,272 @@
"""Tools for the Graph Traversal Maximal Marginal Relevance (MMR) reranking."""
from __future__ import annotations
import dataclasses
from typing import TYPE_CHECKING, Iterable
import numpy as np
from langchain_community.utils.math import cosine_similarity
if TYPE_CHECKING:
from numpy.typing import NDArray
def _emb_to_ndarray(embedding: list[float]) -> NDArray[np.float32]:
emb_array = np.array(embedding, dtype=np.float32)
if emb_array.ndim == 1:
emb_array = np.expand_dims(emb_array, axis=0)
return emb_array
NEG_INF = float("-inf")
@dataclasses.dataclass
class _Candidate:
id: str
similarity: float
weighted_similarity: float
weighted_redundancy: float
score: float = dataclasses.field(init=False)
def __post_init__(self) -> None:
self.score = self.weighted_similarity - self.weighted_redundancy
def update_redundancy(self, new_weighted_redundancy: float) -> None:
if new_weighted_redundancy > self.weighted_redundancy:
self.weighted_redundancy = new_weighted_redundancy
self.score = self.weighted_similarity - self.weighted_redundancy
class MmrHelper:
"""Helper for executing an MMR traversal query.
Args:
query_embedding: The embedding of the query to use for scoring.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding to maximum
diversity and 1 to minimum diversity. Defaults to 0.5.
score_threshold: Only documents with a score greater than or equal
this threshold will be chosen. Defaults to -infinity.
"""
dimensions: int
"""Dimensions of the embedding."""
query_embedding: NDArray[np.float32]
"""Embedding of the query as a (1,dim) ndarray."""
lambda_mult: float
"""Number between 0 and 1.
Determines the degree of diversity among the results with 0 corresponding to
maximum diversity and 1 to minimum diversity."""
lambda_mult_complement: float
"""1 - lambda_mult."""
score_threshold: float
"""Only documents with a score greater than or equal to this will be chosen."""
selected_ids: list[str]
"""List of selected IDs (in selection order)."""
selected_mmr_scores: list[float]
"""List of MMR score at the time each document is selected."""
selected_similarity_scores: list[float]
"""List of similarity score for each selected document."""
selected_embeddings: NDArray[np.float32]
"""(N, dim) ndarray with a row for each selected node."""
candidate_id_to_index: dict[str, int]
"""Dictionary of candidate IDs to indices in candidates and candidate_embeddings."""
candidates: list[_Candidate]
"""List containing information about candidates.
Same order as rows in `candidate_embeddings`.
"""
candidate_embeddings: NDArray[np.float32]
"""(N, dim) ndarray with a row for each candidate."""
best_score: float
best_id: str | None
def __init__(
self,
k: int,
query_embedding: list[float],
lambda_mult: float = 0.5,
score_threshold: float = NEG_INF,
) -> None:
"""Create a new Traversal MMR helper."""
self.query_embedding = _emb_to_ndarray(query_embedding)
self.dimensions = self.query_embedding.shape[1]
self.lambda_mult = lambda_mult
self.lambda_mult_complement = 1 - lambda_mult
self.score_threshold = score_threshold
self.selected_ids = []
self.selected_similarity_scores = []
self.selected_mmr_scores = []
# List of selected embeddings (in selection order).
self.selected_embeddings = np.ndarray((k, self.dimensions), dtype=np.float32)
self.candidate_id_to_index = {}
# List of the candidates.
self.candidates = []
# numpy n-dimensional array of the candidate embeddings.
self.candidate_embeddings = np.ndarray((0, self.dimensions), dtype=np.float32)
self.best_score = NEG_INF
self.best_id = None
def candidate_ids(self) -> Iterable[str]:
"""Return the IDs of the candidates."""
return self.candidate_id_to_index.keys()
def _already_selected_embeddings(self) -> NDArray[np.float32]:
"""Return the selected embeddings sliced to the already assigned values."""
selected = len(self.selected_ids)
return np.vsplit(self.selected_embeddings, [selected])[0]
def _pop_candidate(self, candidate_id: str) -> tuple[float, NDArray[np.float32]]:
"""Pop the candidate with the given ID.
Returns:
The similarity score and embedding of the candidate.
"""
# Get the embedding for the id.
index = self.candidate_id_to_index.pop(candidate_id)
if self.candidates[index].id != candidate_id:
msg = (
"ID in self.candidate_id_to_index doesn't match the ID of the "
"corresponding index in self.candidates"
)
raise ValueError(msg)
embedding: NDArray[np.float32] = self.candidate_embeddings[index].copy()
# Swap that index with the last index in the candidates and
# candidate_embeddings.
last_index = self.candidate_embeddings.shape[0] - 1
similarity = 0.0
if index == last_index:
# Already the last item. We don't need to swap.
similarity = self.candidates.pop().similarity
else:
self.candidate_embeddings[index] = self.candidate_embeddings[last_index]
similarity = self.candidates[index].similarity
old_last = self.candidates.pop()
self.candidates[index] = old_last
self.candidate_id_to_index[old_last.id] = index
self.candidate_embeddings = np.vsplit(self.candidate_embeddings, [last_index])[
0
]
return similarity, embedding
def pop_best(self) -> str | None:
"""Select and pop the best item being considered.
Updates the consideration set based on it.
Returns:
The ID of the best item, or None if no remaining candidate meets the score threshold.
"""
if self.best_id is None or self.best_score < self.score_threshold:
return None
# Get the selection and remove from candidates.
selected_id = self.best_id
selected_similarity, selected_embedding = self._pop_candidate(selected_id)
# Add the ID and embedding to the selected information.
selection_index = len(self.selected_ids)
self.selected_ids.append(selected_id)
self.selected_mmr_scores.append(self.best_score)
self.selected_similarity_scores.append(selected_similarity)
self.selected_embeddings[selection_index] = selected_embedding
# Reset the best score / best ID.
self.best_score = NEG_INF
self.best_id = None
# Update the candidates redundancy, tracking the best node.
if self.candidate_embeddings.shape[0] > 0:
similarity = cosine_similarity(
self.candidate_embeddings, np.expand_dims(selected_embedding, axis=0)
)
for index, candidate in enumerate(self.candidates):
candidate.update_redundancy(similarity[index][0])
if candidate.score > self.best_score:
self.best_score = candidate.score
self.best_id = candidate.id
return selected_id
def add_candidates(self, candidates: dict[str, list[float]]) -> None:
"""Add candidates to the consideration set."""
# Determine the keys to actually include.
# These are the candidates that aren't already selected
# or under consideration.
include_ids_set = set(candidates.keys())
include_ids_set.difference_update(self.selected_ids)
include_ids_set.difference_update(self.candidate_id_to_index.keys())
include_ids = list(include_ids_set)
# Now, build up a matrix of the remaining candidate embeddings.
# And add them to the candidate set.
new_embeddings: NDArray[np.float32] = np.ndarray(
(
len(include_ids),
self.dimensions,
)
)
offset = self.candidate_embeddings.shape[0]
for index, candidate_id in enumerate(include_ids):
if candidate_id in include_ids:
self.candidate_id_to_index[candidate_id] = offset + index
embedding = candidates[candidate_id]
new_embeddings[index] = embedding
# Compute the similarity to the query.
similarity = cosine_similarity(new_embeddings, self.query_embedding)
# Compute the distance metrics of all of pairs in the selected set with
# the new candidates.
redundancy = cosine_similarity(
new_embeddings, self._already_selected_embeddings()
)
for index, candidate_id in enumerate(include_ids):
max_redundancy = 0.0
if redundancy.shape[0] > 0:
max_redundancy = redundancy[index].max()
candidate = _Candidate(
id=candidate_id,
similarity=similarity[index][0],
weighted_similarity=self.lambda_mult * similarity[index][0],
weighted_redundancy=self.lambda_mult_complement * max_redundancy,
)
self.candidates.append(candidate)
if candidate.score >= self.best_score:
self.best_score = candidate.score
self.best_id = candidate.id
# Add the new embeddings to the candidate set.
self.candidate_embeddings = np.vstack(
(
self.candidate_embeddings,
new_embeddings,
)
)
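
To make the selection loop concrete, here is a small self-contained sketch driving `MmrHelper` directly with toy 2-d embeddings (in practice the store's traversal code feeds it):

```python
# Toy sketch of MmrHelper's add_candidates / pop_best loop.
# Embeddings are hand-made 2-d unit vectors; normally the traversal
# code in the store supplies candidates and their embeddings.
import math

from langchain_community.graph_vectorstores.mmr_helper import MmrHelper

def unit(deg: float) -> list[float]:
    rad = math.radians(deg)
    return [math.cos(rad), math.sin(rad)]

helper = MmrHelper(k=2, query_embedding=unit(0.0), lambda_mult=0.25)
helper.add_candidates(
    {
        "a": unit(10.0),   # closest to the query
        "b": unit(-15.0),  # also close, but redundant with "a"
        "c": unit(60.0),   # farther away, adds diversity
    }
)

selected = []
while len(selected) < 2:
    best = helper.pop_best()  # None once no candidate clears score_threshold
    if best is None:
        break
    selected.append(best)

print(selected)  # ["a", "c"]: a low lambda_mult favors diversity over "b"
```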


@@ -4,6 +4,7 @@ import asyncio
import importlib.metadata
import typing
import uuid
import warnings
from typing import (
Any,
Awaitable,
@@ -501,10 +502,13 @@ class Cassandra(VectorStore):
)
def get_by_document_id(self, document_id: str) -> Document | None:
"""Get by document ID.
"""Retrieve a single document from the store, given its document ID.
Args:
document_id: the document ID to get.
document_id: The document ID.
Returns:
The document if it exists. Otherwise None.
"""
row = self.table.get(row_id=document_id)
if row is None:
@@ -512,10 +516,13 @@ class Cassandra(VectorStore):
return self._row_to_document(row=row)
async def aget_by_document_id(self, document_id: str) -> Document | None:
"""Get by document ID.
"""Retrieve a single document from the store, given its document ID.
Args:
document_id: the document ID to get.
document_id: The document ID.
Returns:
The document if it exists. Otherwise None.
"""
row = await self.table.aget(row_id=document_id)
if row is None:
@@ -524,28 +531,30 @@ class Cassandra(VectorStore):
def metadata_search(
self,
metadata: dict[str, Any] = {}, # noqa: B006
filter: dict[str, Any] = {}, # noqa: B006
n: int = 5,
) -> Iterable[Document]:
"""Get documents via a metadata search.
Args:
metadata: the metadata to query for.
filter: the metadata to query for.
n: the maximum number of documents to return.
"""
rows = self.table.find_entries(metadata=metadata, n=n)
rows = self.table.find_entries(metadata=filter, n=n)
return [self._row_to_document(row=row) for row in rows if row]
async def ametadata_search(
self,
metadata: dict[str, Any] = {}, # noqa: B006
filter: dict[str, Any] = {}, # noqa: B006
n: int = 5,
) -> Iterable[Document]:
"""Get documents via a metadata search.
Args:
metadata: the metadata to query for.
filter: the metadata to query for.
n: the maximum number of documents to return.
"""
rows = await self.table.afind_entries(metadata=metadata, n=n)
rows = await self.table.afind_entries(metadata=filter, n=n)
return [self._row_to_document(row=row) for row in rows]
async def asimilarity_search_with_embedding_id_by_vector(
@@ -1126,6 +1135,24 @@ class Cassandra(VectorStore):
body_search=body_search,
)
@staticmethod
def _build_docs_from_texts(
texts: List[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
) -> List[Document]:
docs: List[Document] = []
for i, text in enumerate(texts):
doc = Document(
page_content=text,
)
if metadatas is not None:
doc.metadata = metadatas[i]
if ids is not None:
doc.id = ids[i]
docs.append(doc)
return docs
@classmethod
def from_texts(
cls: Type[CVST],
@@ -1137,7 +1164,6 @@ class Cassandra(VectorStore):
keyspace: Optional[str] = None,
table_name: str = "",
ids: Optional[List[str]] = None,
batch_size: int = 16,
ttl_seconds: Optional[int] = None,
body_index_options: Optional[List[Tuple[str, Any]]] = None,
metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
@@ -1155,16 +1181,32 @@ class Cassandra(VectorStore):
If not provided, it is resolved from cassio.
table_name: Cassandra table (required).
ids: Optional list of IDs associated with the texts.
batch_size: Number of concurrent requests to send to the server.
Defaults to 16.
ttl_seconds: Optional time-to-live for the added texts.
body_index_options: Optional options used to create the body index.
Eg. body_index_options = [cassio.table.cql.STANDARD_ANALYZER]
metadata_indexing: Optional specification of a metadata indexing policy,
i.e. to fine-tune which of the metadata fields are indexed.
It can be a string ("all" or "none"), or a 2-tuple. The following
means that all fields except 'f1', 'f2' ... are NOT indexed:
metadata_indexing=("allowlist", ["f1", "f2", ...])
The following means all fields EXCEPT 'g1', 'g2', ... are indexed:
metadata_indexing("denylist", ["g1", "g2", ...])
The default is to index every metadata field.
Note: if you plan to have massive unique text metadata entries,
consider not indexing them for performance
(and to overcome max-length limitations).
Returns:
a Cassandra vector store.
"""
store = cls(
docs = cls._build_docs_from_texts(
texts=texts,
metadatas=metadatas,
ids=ids,
)
return cls.from_documents(
documents=docs,
embedding=embedding,
session=session,
keyspace=keyspace,
@@ -1172,11 +1214,8 @@ class Cassandra(VectorStore):
ttl_seconds=ttl_seconds,
body_index_options=body_index_options,
metadata_indexing=metadata_indexing,
**kwargs,
)
store.add_texts(
texts=texts, metadatas=metadatas, ids=ids, batch_size=batch_size
)
return store
@classmethod
async def afrom_texts(
@@ -1189,7 +1228,6 @@ class Cassandra(VectorStore):
keyspace: Optional[str] = None,
table_name: str = "",
ids: Optional[List[str]] = None,
concurrency: int = 16,
ttl_seconds: Optional[int] = None,
body_index_options: Optional[List[Tuple[str, Any]]] = None,
metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
@@ -1207,29 +1245,51 @@ class Cassandra(VectorStore):
If not provided, it is resolved from cassio.
table_name: Cassandra table (required).
ids: Optional list of IDs associated with the texts.
concurrency: Number of concurrent queries to send to the database.
Defaults to 16.
ttl_seconds: Optional time-to-live for the added texts.
body_index_options: Optional options used to create the body index.
Eg. body_index_options = [cassio.table.cql.STANDARD_ANALYZER]
metadata_indexing: Optional specification of a metadata indexing policy,
i.e. to fine-tune which of the metadata fields are indexed.
It can be a string ("all" or "none"), or a 2-tuple. The following
means that all fields except 'f1', 'f2' ... are NOT indexed:
metadata_indexing=("allowlist", ["f1", "f2", ...])
The following means all fields EXCEPT 'g1', 'g2', ... are indexed:
metadata_indexing("denylist", ["g1", "g2", ...])
The default is to index every metadata field.
Note: if you plan to have massive unique text metadata entries,
consider not indexing them for performance
(and to overcome max-length limitations).
Returns:
a Cassandra vector store.
"""
store = cls(
docs = cls._build_docs_from_texts(
texts=texts,
metadatas=metadatas,
ids=ids,
)
return await cls.afrom_documents(
documents=docs,
embedding=embedding,
session=session,
keyspace=keyspace,
table_name=table_name,
ttl_seconds=ttl_seconds,
setup_mode=SetupMode.ASYNC,
body_index_options=body_index_options,
metadata_indexing=metadata_indexing,
**kwargs,
)
await store.aadd_texts(
texts=texts, metadatas=metadatas, ids=ids, concurrency=concurrency
)
return store
@staticmethod
def _add_ids_to_docs(
docs: List[Document],
ids: Optional[List[str]] = None,
) -> List[Document]:
if ids is not None:
for doc, doc_id in zip(docs, ids):
doc.id = doc_id
return docs
@classmethod
def from_documents(
@@ -1241,7 +1301,6 @@ class Cassandra(VectorStore):
keyspace: Optional[str] = None,
table_name: str = "",
ids: Optional[List[str]] = None,
batch_size: int = 16,
ttl_seconds: Optional[int] = None,
body_index_options: Optional[List[Tuple[str, Any]]] = None,
metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
@@ -1258,31 +1317,48 @@ class Cassandra(VectorStore):
If not provided, it is resolved from cassio.
table_name: Cassandra table (required).
ids: Optional list of IDs associated with the documents.
batch_size: Number of concurrent requests to send to the server.
Defaults to 16.
ttl_seconds: Optional time-to-live for the added documents.
body_index_options: Optional options used to create the body index.
Eg. body_index_options = [cassio.table.cql.STANDARD_ANALYZER]
metadata_indexing: Optional specification of a metadata indexing policy,
i.e. to fine-tune which of the metadata fields are indexed.
It can be a string ("all" or "none"), or a 2-tuple. The following
means that all fields except 'f1', 'f2' ... are NOT indexed:
metadata_indexing=("allowlist", ["f1", "f2", ...])
The following means all fields EXCEPT 'g1', 'g2', ... are indexed:
metadata_indexing("denylist", ["g1", "g2", ...])
The default is to index every metadata field.
Note: if you plan to have massive unique text metadata entries,
consider not indexing them for performance
(and to overcome max-length limitations).
Returns:
a Cassandra vector store.
"""
texts = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents]
return cls.from_texts(
texts=texts,
if ids is not None:
warnings.warn(
(
"Parameter `ids` to Cassandra's `from_documents` "
"method is deprecated. Please set the supplied documents' "
"`.id` attribute instead. The id attribute of Document "
"is ignored as long as the `ids` parameter is passed."
),
DeprecationWarning,
stacklevel=2,
)
store = cls(
embedding=embedding,
metadatas=metadatas,
session=session,
keyspace=keyspace,
table_name=table_name,
ids=ids,
batch_size=batch_size,
ttl_seconds=ttl_seconds,
body_index_options=body_index_options,
metadata_indexing=metadata_indexing,
**kwargs,
)
store.add_documents(documents=cls._add_ids_to_docs(docs=documents, ids=ids))
return store
@classmethod
async def afrom_documents(
@@ -1294,7 +1370,6 @@ class Cassandra(VectorStore):
keyspace: Optional[str] = None,
table_name: str = "",
ids: Optional[List[str]] = None,
concurrency: int = 16,
ttl_seconds: Optional[int] = None,
body_index_options: Optional[List[Tuple[str, Any]]] = None,
metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
@@ -1311,31 +1386,51 @@ class Cassandra(VectorStore):
If not provided, it is resolved from cassio.
table_name: Cassandra table (required).
ids: Optional list of IDs associated with the documents.
concurrency: Number of concurrent queries to send to the database.
Defaults to 16.
ttl_seconds: Optional time-to-live for the added documents.
body_index_options: Optional options used to create the body index.
Eg. body_index_options = [cassio.table.cql.STANDARD_ANALYZER]
metadata_indexing: Optional specification of a metadata indexing policy,
i.e. to fine-tune which of the metadata fields are indexed.
It can be a string ("all" or "none"), or a 2-tuple. The following
means that all fields except 'f1', 'f2' ... are NOT indexed:
metadata_indexing=("allowlist", ["f1", "f2", ...])
The following means all fields EXCEPT 'g1', 'g2', ... are indexed:
metadata_indexing("denylist", ["g1", "g2", ...])
The default is to index every metadata field.
Note: if you plan to have massive unique text metadata entries,
consider not indexing them for performance
(and to overcome max-length limitations).
Returns:
a Cassandra vector store.
"""
texts = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents]
return await cls.afrom_texts(
texts=texts,
if ids is not None:
warnings.warn(
(
"Parameter `ids` to Cassandra's `afrom_documents` "
"method is deprecated. Please set the supplied documents' "
"`.id` attribute instead. The id attribute of Document "
"is ignored as long as the `ids` parameter is passed."
),
DeprecationWarning,
stacklevel=2,
)
store = cls(
embedding=embedding,
metadatas=metadatas,
session=session,
keyspace=keyspace,
table_name=table_name,
ids=ids,
concurrency=concurrency,
ttl_seconds=ttl_seconds,
setup_mode=SetupMode.ASYNC,
body_index_options=body_index_options,
metadata_indexing=metadata_indexing,
**kwargs,
)
await store.aadd_documents(
documents=cls._add_ids_to_docs(docs=documents, ids=ids)
)
return store
def as_retriever(
self,
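
The now-preferred calling pattern per the deprecation warning above (connection values are placeholders):

```python
# Sketch of the preferred pattern after this change: set Document.id
# directly instead of passing `ids=` to (a)from_documents.
# Connection values below are placeholders.
from langchain_core.documents import Document
from langchain_community.vectorstores import Cassandra

docs = [
    Document(id="doc-1", page_content="Hello World", metadata={"topic": "greeting"}),
    Document(id="doc-2", page_content="Hello Earth", metadata={"topic": "greeting"}),
]
store = Cassandra.from_documents(
    documents=docs,
    embedding=my_embeddings,  # placeholder Embeddings instance
    session=session,          # placeholder Cassandra session
    keyspace="my_keyspace",
    table_name="my_table",
)
```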


@@ -1,116 +1,255 @@
import math
import os
from typing import Iterable, List, Optional, Type
"""Test of Apache Cassandra graph vector g_store class `CassandraGraphVectorStore`"""
import json
import os
import random
from contextlib import contextmanager
from typing import Any, Generator, Iterable, List, Optional
import pytest
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_community.graph_vectorstores import CassandraGraphVectorStore
from langchain_community.graph_vectorstores.links import METADATA_LINKS_KEY, Link
CASSANDRA_DEFAULT_KEYSPACE = "graph_test_keyspace"
def _get_graph_store(
embedding_class: Type[Embeddings], documents: Iterable[Document] = ()
) -> CassandraGraphVectorStore:
import cassio
from cassandra.cluster import Cluster
from cassio.config import check_resolve_session, resolve_keyspace
node_table = "graph_test_node_table"
edge_table = "graph_test_edge_table"
if any(
env_var in os.environ
for env_var in [
"CASSANDRA_CONTACT_POINTS",
"ASTRA_DB_APPLICATION_TOKEN",
"ASTRA_DB_INIT_STRING",
]
):
cassio.init(auto=True)
session = check_resolve_session()
else:
cluster = Cluster()
session = cluster.connect()
keyspace = resolve_keyspace() or CASSANDRA_DEFAULT_KEYSPACE
cassio.init(session=session, keyspace=keyspace)
# ensure keyspace exists
session.execute(
(
f"CREATE KEYSPACE IF NOT EXISTS {keyspace} "
f"WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}"
from langchain_community.graph_vectorstores.base import Node
from langchain_community.graph_vectorstores.links import (
METADATA_LINKS_KEY,
Link,
add_links,
)
from tests.integration_tests.cache.fake_embeddings import (
AngularTwoDimensionalEmbeddings,
FakeEmbeddings,
)
session.execute(f"DROP TABLE IF EXISTS {keyspace}.{node_table}")
session.execute(f"DROP TABLE IF EXISTS {keyspace}.{edge_table}")
store = CassandraGraphVectorStore.from_documents(
documents,
embedding=embedding_class(),
session=session,
keyspace=keyspace,
node_table=node_table,
targets_table=edge_table,
)
return store
TEST_KEYSPACE = "graph_test_keyspace"
class FakeEmbeddings(Embeddings):
"""Fake embeddings functionality for testing."""
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Return simple embeddings.
Embeddings encode each text as its index."""
return [[float(1.0)] * 9 + [float(i)] for i in range(len(texts))]
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
return self.embed_documents(texts)
def embed_query(self, text: str) -> List[float]:
"""Return constant query embeddings.
Embeddings are identical to embed_documents(texts)[0].
Distance to each text will be that text's index,
as it was passed to embed_documents."""
return [float(1.0)] * 9 + [float(0.0)]
async def aembed_query(self, text: str) -> List[float]:
return self.embed_query(text)
class AngularTwoDimensionalEmbeddings(Embeddings):
"""
From angles (as strings in units of pi) to unit embedding vectors on a circle.
class ParserEmbeddings(Embeddings):
"""Parse input texts: if they are json for a List[float], fine.
Otherwise, return all zeros and call it a day.
"""
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""
Make a list of texts into a list of embedding vectors.
"""
return [self.embed_query(text) for text in texts]
def __init__(self, dimension: int) -> None:
self.dimension = dimension
def embed_query(self, text: str) -> List[float]:
"""
Convert input text to a 'vector' (list of floats).
If the text is a number, use it as the angle for the
unit vector in units of pi.
Any other input text becomes the singular result [0, 0] !
"""
def embed_documents(self, texts: list[str]) -> list[list[float]]:
return [self.embed_query(txt) for txt in texts]
def embed_query(self, text: str) -> list[float]:
try:
angle = float(text)
return [math.cos(angle * math.pi), math.sin(angle * math.pi)]
except ValueError:
# Assume: just test string, no attention is paid to values.
return [0.0, 0.0]
vals = json.loads(text)
except json.JSONDecodeError:
return [0.0] * self.dimension
else:
assert len(vals) == self.dimension
return vals
@pytest.fixture
def embedding_d2() -> Embeddings:
return ParserEmbeddings(dimension=2)
class EarthEmbeddings(Embeddings):
def get_vector_near(self, value: float) -> List[float]:
base_point = [value, (1 - value**2) ** 0.5]
fluctuation = random.random() / 100.0
return [base_point[0] + fluctuation, base_point[1] - fluctuation]
def embed_documents(self, texts: list[str]) -> list[list[float]]:
return [self.embed_query(txt) for txt in texts]
def embed_query(self, text: str) -> list[float]:
words = set(text.lower().split())
if "earth" in words:
vector = self.get_vector_near(0.9)
elif {"planet", "world", "globe", "sphere"}.intersection(words):
vector = self.get_vector_near(0.8)
else:
vector = self.get_vector_near(0.1)
return vector
def _result_ids(docs: Iterable[Document]) -> List[Optional[str]]:
return [doc.id for doc in docs]
def test_mmr_traversal() -> None:
@pytest.fixture
def graph_vector_store_docs() -> list[Document]:
"""
Test end to end construction and MMR search.
This is a set of Documents to pre-populate a graph vector store,
with entries placed in a certain way.
Space of the entries (under Euclidean similarity):
A0 (*)
.... AL AR <....
: | :
: | ^ :
v | . v
| :
TR | : BL
T0 --------------x-------------- B0
TL | : BR
| :
| .
| .
|
FL FR
F0
the query point is meant to be at (*).
the A nodes are linked bidirectionally with the B nodes,
the A nodes have outgoing links to the T nodes,
the A nodes have incoming links from the F nodes.
The links are like: L with L, 0 with 0 and R with R.
"""
docs_a = [
Document(id="AL", page_content="[-1, 9]", metadata={"label": "AL"}),
Document(id="A0", page_content="[0, 10]", metadata={"label": "A0"}),
Document(id="AR", page_content="[1, 9]", metadata={"label": "AR"}),
]
docs_b = [
Document(id="BL", page_content="[9, 1]", metadata={"label": "BL"}),
Document(id="B0", page_content="[10, 0]", metadata={"label": "B0"}),
Document(id="BL", page_content="[9, -1]", metadata={"label": "BR"}),
]
docs_f = [
Document(id="FL", page_content="[1, -9]", metadata={"label": "FL"}),
Document(id="F0", page_content="[0, -10]", metadata={"label": "F0"}),
Document(id="FR", page_content="[-1, -9]", metadata={"label": "FR"}),
]
docs_t = [
Document(id="TL", page_content="[-9, -1]", metadata={"label": "TL"}),
Document(id="T0", page_content="[-10, 0]", metadata={"label": "T0"}),
Document(id="TR", page_content="[-9, 1]", metadata={"label": "TR"}),
]
for doc_a, suffix in zip(docs_a, ["l", "0", "r"]):
add_links(doc_a, Link.bidir(kind="ab_example", tag=f"tag_{suffix}"))
add_links(doc_a, Link.outgoing(kind="at_example", tag=f"tag_{suffix}"))
add_links(doc_a, Link.incoming(kind="af_example", tag=f"tag_{suffix}"))
for doc_b, suffix in zip(docs_b, ["l", "0", "r"]):
add_links(doc_b, Link.bidir(kind="ab_example", tag=f"tag_{suffix}"))
for doc_t, suffix in zip(docs_t, ["l", "0", "r"]):
add_links(doc_t, Link.incoming(kind="at_example", tag=f"tag_{suffix}"))
for doc_f, suffix in zip(docs_f, ["l", "0", "r"]):
add_links(doc_f, Link.outgoing(kind="af_example", tag=f"tag_{suffix}"))
return docs_a + docs_b + docs_f + docs_t
class CassandraSession:
table_name: str
session: Any
def __init__(self, table_name: str, session: Any):
self.table_name = table_name
self.session = session
@contextmanager
def get_cassandra_session(
table_name: str, drop: bool = True
) -> Generator[CassandraSession, None, None]:
"""Initialize the Cassandra cluster and session"""
from cassandra.cluster import Cluster
if "CASSANDRA_CONTACT_POINTS" in os.environ:
contact_points = [
cp.strip()
for cp in os.environ["CASSANDRA_CONTACT_POINTS"].split(",")
if cp.strip()
]
else:
contact_points = None
cluster = Cluster(contact_points)
session = cluster.connect()
try:
session.execute(
(
f"CREATE KEYSPACE IF NOT EXISTS {TEST_KEYSPACE}"
" WITH replication = "
"{'class': 'SimpleStrategy', 'replication_factor': 1}"
)
)
if drop:
session.execute(f"DROP TABLE IF EXISTS {TEST_KEYSPACE}.{table_name}")
# Yield the session for usage
yield CassandraSession(table_name=table_name, session=session)
finally:
# Ensure proper shutdown/cleanup of resources
session.shutdown()
cluster.shutdown()
@pytest.fixture(scope="function")
def graph_vector_store_angular(
table_name: str = "graph_test_table",
) -> Generator[CassandraGraphVectorStore, None, None]:
with get_cassandra_session(table_name=table_name) as session:
yield CassandraGraphVectorStore(
embedding=AngularTwoDimensionalEmbeddings(),
session=session.session,
keyspace=TEST_KEYSPACE,
table_name=session.table_name,
)
@pytest.fixture(scope="function")
def graph_vector_store_earth(
table_name: str = "graph_test_table",
) -> Generator[CassandraGraphVectorStore, None, None]:
with get_cassandra_session(table_name=table_name) as session:
yield CassandraGraphVectorStore(
embedding=EarthEmbeddings(),
session=session.session,
keyspace=TEST_KEYSPACE,
table_name=session.table_name,
)
@pytest.fixture(scope="function")
def graph_vector_store_fake(
table_name: str = "graph_test_table",
) -> Generator[CassandraGraphVectorStore, None, None]:
with get_cassandra_session(table_name=table_name) as session:
yield CassandraGraphVectorStore(
embedding=FakeEmbeddings(),
session=session.session,
keyspace=TEST_KEYSPACE,
table_name=session.table_name,
)
@pytest.fixture(scope="function")
def graph_vector_store_d2(
embedding_d2: Embeddings,
table_name: str = "graph_test_table",
) -> Generator[CassandraGraphVectorStore, None, None]:
with get_cassandra_session(table_name=table_name) as session:
yield CassandraGraphVectorStore(
embedding=embedding_d2,
session=session.session,
keyspace=TEST_KEYSPACE,
table_name=session.table_name,
)
@pytest.fixture(scope="function")
def populated_graph_vector_store_d2(
graph_vector_store_d2: CassandraGraphVectorStore,
graph_vector_store_docs: list[Document],
) -> Generator[CassandraGraphVectorStore, None, None]:
graph_vector_store_d2.add_documents(graph_vector_store_docs)
yield graph_vector_store_d2
def test_mmr_traversal(graph_vector_store_angular: CassandraGraphVectorStore) -> None:
""" Test end to end construction and MMR search.
The embedding function used here ensures `texts` become
the following vectors on a circle (numbered v0 through v3):
@@ -128,140 +267,128 @@ def test_mmr_traversal() -> None:
Both v2 and v3 are reachable via edges from v0, so once it is
selected, those are both considered.
"""
store = _get_graph_store(AngularTwoDimensionalEmbeddings)
v0 = Document(
v0 = Node(
id="v0",
page_content="-0.124",
metadata={
METADATA_LINKS_KEY: [
text="-0.124",
links=[
Link.outgoing(kind="explicit", tag="link"),
],
},
)
v1 = Document(
v1 = Node(
id="v1",
page_content="+0.127",
text="+0.127",
)
v2 = Document(
v2 = Node(
id="v2",
page_content="+0.25",
metadata={
METADATA_LINKS_KEY: [
text="+0.25",
links=[
Link.incoming(kind="explicit", tag="link"),
],
},
)
v3 = Document(
v3 = Node(
id="v3",
page_content="+1.0",
metadata={
METADATA_LINKS_KEY: [
text="+1.0",
links=[
Link.incoming(kind="explicit", tag="link"),
],
},
)
store.add_documents([v0, v1, v2, v3])
results = store.mmr_traversal_search("0.0", k=2, fetch_k=2)
g_store = graph_vector_store_angular
g_store.add_nodes([v0, v1, v2, v3])
results = g_store.mmr_traversal_search("0.0", k=2, fetch_k=2)
assert _result_ids(results) == ["v0", "v2"]
# With max depth 0, no edges are traversed, so this doesn't reach v2 or v3.
# So it ends up picking "v1" even though it's similar to "v0".
results = store.mmr_traversal_search("0.0", k=2, fetch_k=2, depth=0)
results = g_store.mmr_traversal_search("0.0", k=2, fetch_k=2, depth=0)
assert _result_ids(results) == ["v0", "v1"]
# With max depth 0 but higher `fetch_k`, we encounter v2
results = store.mmr_traversal_search("0.0", k=2, fetch_k=3, depth=0)
results = g_store.mmr_traversal_search("0.0", k=2, fetch_k=3, depth=0)
assert _result_ids(results) == ["v0", "v2"]
# v0 score is .46, v2 score is 0.16 so it won't be chosen.
results = store.mmr_traversal_search("0.0", k=2, score_threshold=0.2)
results = g_store.mmr_traversal_search("0.0", k=2, score_threshold=0.2)
assert _result_ids(results) == ["v0"]
# with k=4 we should get all of the documents.
results = store.mmr_traversal_search("0.0", k=4)
results = g_store.mmr_traversal_search("0.0", k=4)
assert _result_ids(results) == ["v0", "v2", "v1", "v3"]
def test_write_retrieve_keywords() -> None:
from langchain_openai import OpenAIEmbeddings
greetings = Document(
def test_write_retrieve_keywords(
graph_vector_store_earth: CassandraGraphVectorStore,
) -> None:
greetings = Node(
id="greetings",
page_content="Typical Greetings",
metadata={
METADATA_LINKS_KEY: [
text="Typical Greetings",
links=[
Link.incoming(kind="parent", tag="parent"),
],
},
)
doc1 = Document(
node1 = Node(
id="doc1",
page_content="Hello World",
metadata={
METADATA_LINKS_KEY: [
text="Hello World",
links=[
Link.outgoing(kind="parent", tag="parent"),
Link.bidir(kind="kw", tag="greeting"),
Link.bidir(kind="kw", tag="world"),
],
},
)
doc2 = Document(
node2 = Node(
id="doc2",
page_content="Hello Earth",
metadata={
METADATA_LINKS_KEY: [
text="Hello Earth",
links=[
Link.outgoing(kind="parent", tag="parent"),
Link.bidir(kind="kw", tag="greeting"),
Link.bidir(kind="kw", tag="earth"),
],
},
)
store = _get_graph_store(OpenAIEmbeddings, [greetings, doc1, doc2])
g_store = graph_vector_store_earth
g_store.add_nodes(nodes=[greetings, node1, node2])
# Doc2 is more similar, but World and Earth are similar enough that doc1 also
# shows up.
results: Iterable[Document] = store.similarity_search("Earth", k=2)
results: Iterable[Document] = g_store.similarity_search("Earth", k=2)
assert _result_ids(results) == ["doc2", "doc1"]
results = store.similarity_search("Earth", k=1)
results = g_store.similarity_search("Earth", k=1)
assert _result_ids(results) == ["doc2"]
results = store.traversal_search("Earth", k=2, depth=0)
results = g_store.traversal_search("Earth", k=2, depth=0)
assert _result_ids(results) == ["doc2", "doc1"]
results = store.traversal_search("Earth", k=2, depth=1)
results = g_store.traversal_search("Earth", k=2, depth=1)
assert _result_ids(results) == ["doc2", "doc1", "greetings"]
# K=1 only pulls in doc2 (Hello Earth)
results = store.traversal_search("Earth", k=1, depth=0)
results = g_store.traversal_search("Earth", k=1, depth=0)
assert _result_ids(results) == ["doc2"]
# K=1 only pulls in doc2 (Hello Earth). Depth=1 traverses to parent and via
# keyword edge.
results = store.traversal_search("Earth", k=1, depth=1)
results = g_store.traversal_search("Earth", k=1, depth=1)
assert set(_result_ids(results)) == {"doc2", "doc1", "greetings"}
def test_metadata() -> None:
store = _get_graph_store(FakeEmbeddings)
store.add_documents(
[
Document(
def test_metadata(graph_vector_store_fake: CassandraGraphVectorStore) -> None:
doc_a = Node(
id="a",
page_content="A",
metadata={
METADATA_LINKS_KEY: [
text="A",
metadata={"other": "some other field"},
links=[
Link.incoming(kind="hyperlink", tag="http://a"),
Link.bidir(kind="other", tag="foo"),
],
"other": "some other field",
},
)
]
)
results = store.similarity_search("A")
g_store = graph_vector_store_fake
g_store.add_nodes([doc_a])
results = g_store.similarity_search("A")
assert len(results) == 1
assert results[0].id == "a"
metadata = results[0].metadata
@@ -270,3 +397,274 @@ def test_metadata() -> None:
Link.incoming(kind="hyperlink", tag="http://a"),
Link.bidir(kind="other", tag="foo"),
}
class TestCassandraGraphVectorStore:
def test_gvs_similarity_search_sync(
self,
populated_graph_vector_store_d2: CassandraGraphVectorStore,
) -> None:
"""Simple (non-graph) similarity search on a graph vector g_store."""
g_store = populated_graph_vector_store_d2
ss_response = g_store.similarity_search(query="[2, 10]", k=2)
ss_labels = [doc.metadata["label"] for doc in ss_response]
assert ss_labels == ["AR", "A0"]
ss_by_v_response = g_store.similarity_search_by_vector(embedding=[2, 10], k=2)
ss_by_v_labels = [doc.metadata["label"] for doc in ss_by_v_response]
assert ss_by_v_labels == ["AR", "A0"]
async def test_gvs_similarity_search_async(
self,
populated_graph_vector_store_d2: CassandraGraphVectorStore,
) -> None:
"""Simple (non-graph) similarity search on a graph vector store."""
g_store = populated_graph_vector_store_d2
ss_response = await g_store.asimilarity_search(query="[2, 10]", k=2)
ss_labels = [doc.metadata["label"] for doc in ss_response]
assert ss_labels == ["AR", "A0"]
ss_by_v_response = await g_store.asimilarity_search_by_vector(
embedding=[2, 10], k=2
)
ss_by_v_labels = [doc.metadata["label"] for doc in ss_by_v_response]
assert ss_by_v_labels == ["AR", "A0"]
def test_gvs_traversal_search_sync(
self,
populated_graph_vector_store_d2: CassandraGraphVectorStore,
) -> None:
"""Graph traversal search on a graph vector store."""
g_store = populated_graph_vector_store_d2
ts_response = g_store.traversal_search(query="[2, 10]", k=2, depth=2)
# this is a set, as some of the internals of trav.search are set-driven
# so ordering is not deterministic:
ts_labels = {doc.metadata["label"] for doc in ts_response}
assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"}
async def test_gvs_traversal_search_async(
self,
populated_graph_vector_store_d2: CassandraGraphVectorStore,
) -> None:
"""Graph traversal search on a graph vector store."""
g_store = populated_graph_vector_store_d2
ts_labels = set()
async for doc in g_store.atraversal_search(query="[2, 10]", k=2, depth=2):
ts_labels.add(doc.metadata["label"])
# this is a set, as some of the internals of trav.search are set-driven
# so ordering is not deterministic:
assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"}
def test_gvs_mmr_traversal_search_sync(
self,
populated_graph_vector_store_d2: CassandraGraphVectorStore,
) -> None:
"""MMR Graph traversal search on a graph vector store."""
g_store = populated_graph_vector_store_d2
mt_response = g_store.mmr_traversal_search(
query="[2, 10]",
k=2,
depth=2,
fetch_k=1,
adjacent_k=2,
lambda_mult=0.1,
)
# TODO: can this rightfully be a list (or must it be a set)?
mt_labels = {doc.metadata["label"] for doc in mt_response}
assert mt_labels == {"AR", "BR"}
async def test_gvs_mmr_traversal_search_async(
self,
populated_graph_vector_store_d2: CassandraGraphVectorStore,
) -> None:
"""MMR Graph traversal search on a graph vector store."""
g_store = populated_graph_vector_store_d2
mt_labels = set()
async for doc in g_store.ammr_traversal_search(
query="[2, 10]",
k=2,
depth=2,
fetch_k=1,
adjacent_k=2,
lambda_mult=0.1,
):
mt_labels.add(doc.metadata["label"])
# TODO: can this rightfully be a list (or must it be a set)?
assert mt_labels == {"AR", "BR"}
def test_gvs_metadata_search_sync(
self,
populated_graph_vector_store_d2: CassandraGraphVectorStore,
) -> None:
"""Metadata search on a graph vector store."""
g_store = populated_graph_vector_store_d2
mt_response = g_store.metadata_search(
filter={"label": "T0"},
n=2,
)
doc: Document = next(iter(mt_response))
assert doc.page_content == "[-10, 0]"
links = doc.metadata["links"]
assert len(links) == 1
link: Link = links.pop()
assert isinstance(link, Link)
assert link.direction == "in"
assert link.kind == "at_example"
assert link.tag == "tag_0"
async def test_gvs_metadata_search_async(
self,
populated_graph_vector_store_d2: CassandraGraphVectorStore,
) -> None:
"""Metadata search on a graph vector store."""
g_store = populated_graph_vector_store_d2
mt_response = await g_store.ametadata_search(
filter={"label": "T0"},
n=2,
)
doc: Document = next(iter(mt_response))
assert doc.page_content == "[-10, 0]"
links: set[Link] = doc.metadata["links"]
assert len(links) == 1
link: Link = links.pop()
assert isinstance(link, Link)
assert link.direction == "in"
assert link.kind == "at_example"
assert link.tag == "tag_0"
def test_gvs_get_by_document_id_sync(
self,
populated_graph_vector_store_d2: CassandraGraphVectorStore,
) -> None:
"""Get by document_id on a graph vector store."""
g_store = populated_graph_vector_store_d2
doc = g_store.get_by_document_id(document_id="FL")
assert doc is not None
assert doc.page_content == "[1, -9]"
links = doc.metadata["links"]
assert len(links) == 1
link: Link = links.pop()
assert isinstance(link, Link)
assert link.direction == "out"
assert link.kind == "af_example"
assert link.tag == "tag_l"
invalid_doc = g_store.get_by_document_id(document_id="invalid")
assert invalid_doc is None
async def test_gvs_get_by_document_id_async(
self,
populated_graph_vector_store_d2: CassandraGraphVectorStore,
) -> None:
"""Get by document_id on a graph vector store."""
g_store = populated_graph_vector_store_d2
doc = await g_store.aget_by_document_id(document_id="FL")
assert doc is not None
assert doc.page_content == "[1, -9]"
links = doc.metadata["links"]
assert len(links) == 1
link: Link = links.pop()
assert isinstance(link, Link)
assert link.direction == "out"
assert link.kind == "af_example"
assert link.tag == "tag_l"
invalid_doc = await g_store.aget_by_document_id(document_id="invalid")
assert invalid_doc is None
def test_gvs_from_texts(
self,
graph_vector_store_d2: CassandraGraphVectorStore,
) -> None:
g_store = graph_vector_store_d2
g_store.add_texts(
texts=["[1, 2]"],
metadatas=[{"md": 1}],
ids=["x_id"],
)
hits = g_store.similarity_search("[2, 1]", k=2)
assert len(hits) == 1
assert hits[0].page_content == "[1, 2]"
assert hits[0].id == "x_id"
# there may be more re:graph structure.
assert hits[0].metadata["md"] == "1.0"
def test_gvs_from_documents_containing_ids(
self,
graph_vector_store_d2: CassandraGraphVectorStore,
) -> None:
the_document = Document(
page_content="[1, 2]",
metadata={"md": 1},
id="x_id",
)
g_store = graph_vector_store_d2
g_store.add_documents([the_document])
hits = g_store.similarity_search("[2, 1]", k=2)
assert len(hits) == 1
assert hits[0].page_content == "[1, 2]"
assert hits[0].id == "x_id"
# there may be more re:graph structure.
assert hits[0].metadata["md"] == "1.0"
def test_gvs_add_nodes_sync(
self,
*,
graph_vector_store_d2: CassandraGraphVectorStore,
) -> None:
links0 = [
Link(kind="kA", direction="out", tag="tA"),
Link(kind="kB", direction="bidir", tag="tB"),
]
links1 = [
Link(kind="kC", direction="in", tag="tC"),
]
nodes = [
Node(id="id0", text="[1, 0]", metadata={"m": 0}, links=links0),
Node(text="[-1, 0]", metadata={"m": 1}, links=links1),
]
graph_vector_store_d2.add_nodes(nodes)
hits = graph_vector_store_d2.similarity_search_by_vector([0.9, 0.1])
assert len(hits) == 2
assert hits[0].id == "id0"
assert hits[0].page_content == "[1, 0]"
md0 = hits[0].metadata
assert md0["m"] == "0.0"
assert any(isinstance(v, set) for k, v in md0.items() if k != "m")
assert hits[1].id != "id0"
assert hits[1].page_content == "[-1, 0]"
md1 = hits[1].metadata
assert md1["m"] == "1.0"
assert any(isinstance(v, set) for k, v in md1.items() if k != "m")
async def test_gvs_add_nodes_async(
self,
*,
graph_vector_store_d2: CassandraGraphVectorStore,
) -> None:
links0 = [
Link(kind="kA", direction="out", tag="tA"),
Link(kind="kB", direction="bidir", tag="tB"),
]
links1 = [
Link(kind="kC", direction="in", tag="tC"),
]
nodes = [
Node(id="id0", text="[1, 0]", metadata={"m": 0}, links=links0),
Node(text="[-1, 0]", metadata={"m": 1}, links=links1),
]
async for _ in graph_vector_store_d2.aadd_nodes(nodes):
pass
hits = await graph_vector_store_d2.asimilarity_search_by_vector([0.9, 0.1])
assert len(hits) == 2
assert hits[0].id == "id0"
assert hits[0].page_content == "[1, 0]"
md0 = hits[0].metadata
assert md0["m"] == "0.0"
assert any(isinstance(v, set) for k, v in md0.items() if k != "m")
assert hits[1].id != "id0"
assert hits[1].page_content == "[-1, 0]"
md1 = hits[1].metadata
assert md1["m"] == "1.0"
assert any(isinstance(v, set) for k, v in md1.items() if k != "m")


@@ -0,0 +1,269 @@
"""Test of Upgrading to Apache Cassandra graph vector store class:
`CassandraGraphVectorStore` from an existing table used
by the Cassandra vector store class: `Cassandra`
"""
from __future__ import annotations
import json
import os
from contextlib import contextmanager
from typing import Any, Generator, Iterable, Optional, Tuple, Union
import pytest
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_community.graph_vectorstores import CassandraGraphVectorStore
from langchain_community.utilities.cassandra import SetupMode
from langchain_community.vectorstores import Cassandra
TEST_KEYSPACE = "graph_test_keyspace"
TABLE_NAME_ALLOW_INDEXING = "allow_graph_table"
TABLE_NAME_DEFAULT = "default_graph_table"
TABLE_NAME_DENY_INDEXING = "deny_graph_table"
class ParserEmbeddings(Embeddings):
"""Parse input texts: if they are json for a List[float], fine.
Otherwise, return all zeros and call it a day.
"""
def __init__(self, dimension: int) -> None:
self.dimension = dimension
def embed_documents(self, texts: list[str]) -> list[list[float]]:
return [self.embed_query(txt) for txt in texts]
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
return self.embed_documents(texts)
def embed_query(self, text: str) -> list[float]:
try:
vals = json.loads(text)
except json.JSONDecodeError:
return [0.0] * self.dimension
else:
assert len(vals) == self.dimension
return vals
async def aembed_query(self, text: str) -> list[float]:
return self.embed_query(text)
@pytest.fixture
def embedding_d2() -> Embeddings:
return ParserEmbeddings(dimension=2)
class CassandraSession:
table_name: str
session: Any
def __init__(self, table_name: str, session: Any):
self.table_name = table_name
self.session = session
@contextmanager
def get_cassandra_session(
table_name: str, drop: bool = True
) -> Generator[CassandraSession, None, None]:
"""Initialize the Cassandra cluster and session"""
from cassandra.cluster import Cluster
if "CASSANDRA_CONTACT_POINTS" in os.environ:
contact_points = [
cp.strip()
for cp in os.environ["CASSANDRA_CONTACT_POINTS"].split(",")
if cp.strip()
]
else:
contact_points = None
cluster = Cluster(contact_points)
session = cluster.connect()
try:
session.execute(
(
f"CREATE KEYSPACE IF NOT EXISTS {TEST_KEYSPACE}"
" WITH replication = "
"{'class': 'SimpleStrategy', 'replication_factor': 1}"
)
)
if drop:
session.execute(f"DROP TABLE IF EXISTS {TEST_KEYSPACE}.{table_name}")
# Yield the session for usage
yield CassandraSession(table_name=table_name, session=session)
finally:
# Ensure proper shutdown/cleanup of resources
session.shutdown()
cluster.shutdown()
@contextmanager
def vector_store(
embedding: Embeddings,
table_name: str,
setup_mode: SetupMode,
metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
drop: bool = True,
) -> Generator[Cassandra, None, None]:
with get_cassandra_session(table_name=table_name, drop=drop) as session:
yield Cassandra(
table_name=session.table_name,
keyspace=TEST_KEYSPACE,
session=session.session,
embedding=embedding,
setup_mode=setup_mode,
metadata_indexing=metadata_indexing,
)
@contextmanager
def graph_vector_store(
embedding: Embeddings,
table_name: str,
setup_mode: SetupMode,
metadata_deny_list: Optional[list[str]] = None,
drop: bool = True,
) -> Generator[CassandraGraphVectorStore, None, None]:
with get_cassandra_session(table_name=table_name, drop=drop) as session:
yield CassandraGraphVectorStore(
table_name=session.table_name,
keyspace=TEST_KEYSPACE,
session=session.session,
embedding=embedding,
setup_mode=setup_mode,
metadata_deny_list=metadata_deny_list,
)
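# Note: `metadata_deny_list=["test"]` on the graph store plays the same
# role as the `("denylist", ["test"])` metadata_indexing policy on the
# base `Cassandra` store; the parametrized upgrade cases below rely on
# that correspondence.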
def _vs_indexing_policy(table_name: str) -> Union[Tuple[str, Iterable[str]], str]:
if table_name == TABLE_NAME_ALLOW_INDEXING:
return ("allowlist", ["test"])
if table_name == TABLE_NAME_DEFAULT:
return "all"
if table_name == TABLE_NAME_DENY_INDEXING:
return ("denylist", ["test"])
msg = f"Unknown table_name: {table_name} in _vs_indexing_policy()"
raise ValueError(msg)
class TestUpgradeToGraphVectorStore:
@pytest.mark.parametrize(
("table_name", "gvs_setup_mode", "gvs_metadata_deny_list"),
[
(TABLE_NAME_DEFAULT, SetupMode.SYNC, None),
(TABLE_NAME_DENY_INDEXING, SetupMode.SYNC, ["test"]),
(TABLE_NAME_DEFAULT, SetupMode.OFF, None),
(TABLE_NAME_DENY_INDEXING, SetupMode.OFF, ["test"]),
# For this case, even though the passed policy doesn't
# match the policy used to create the table,
# there is no error, since SetupMode is OFF and
# no attempt is made to re-create the table.
(TABLE_NAME_DENY_INDEXING, SetupMode.OFF, None),
],
ids=[
"default_upgrade_no_policy_sync",
"deny_list_upgrade_same_policy_sync",
"default_upgrade_no_policy_off",
"deny_list_upgrade_same_policy_off",
"deny_list_upgrade_change_policy_off",
],
)
def test_upgrade_to_gvs_success_sync(
self,
*,
embedding_d2: Embeddings,
gvs_setup_mode: SetupMode,
table_name: str,
gvs_metadata_deny_list: Optional[list[str]],
) -> None:
doc_id = "AL"
doc_al = Document(id=doc_id, page_content="[-1, 9]", metadata={"label": "AL"})
# Create vector store using SetupMode.SYNC
with vector_store(
embedding=embedding_d2,
table_name=table_name,
setup_mode=SetupMode.SYNC,
metadata_indexing=_vs_indexing_policy(table_name=table_name),
drop=True,
) as v_store:
# load a document to the vector store
v_store.add_documents([doc_al])
# get the document from the vector store
v_doc = v_store.get_by_document_id(document_id=doc_id)
assert v_doc is not None
assert v_doc.page_content == doc_al.page_content
# Create a GRAPH vector store using the existing table from above
# with setup_mode=gvs_setup_mode and metadata_deny_list=gvs_metadata_deny_list
with graph_vector_store(
embedding=embedding_d2,
table_name=table_name,
setup_mode=gvs_setup_mode,
metadata_deny_list=gvs_metadata_deny_list,
drop=False,
) as gv_store:
# get the document from the GRAPH vector store
gv_doc = gv_store.get_by_document_id(document_id=doc_id)
assert gv_doc is not None
assert gv_doc.page_content == doc_al.page_content
@pytest.mark.parametrize(
("table_name", "gvs_setup_mode", "gvs_metadata_deny_list"),
[
(TABLE_NAME_DEFAULT, SetupMode.ASYNC, None),
(TABLE_NAME_DENY_INDEXING, SetupMode.ASYNC, ["test"]),
],
ids=[
"default_upgrade_no_policy_async",
"deny_list_upgrade_same_policy_async",
],
)
async def test_upgrade_to_gvs_success_async(
self,
*,
embedding_d2: Embeddings,
gvs_setup_mode: SetupMode,
table_name: str,
gvs_metadata_deny_list: Optional[list[str]],
) -> None:
doc_id = "AL"
doc_al = Document(id=doc_id, page_content="[-1, 9]", metadata={"label": "AL"})
# Create vector store using SetupMode.ASYNC
with vector_store(
embedding=embedding_d2,
table_name=table_name,
setup_mode=SetupMode.ASYNC,
metadata_indexing=_vs_indexing_policy(table_name=table_name),
drop=True,
) as v_store:
# load a document to the vector store
await v_store.aadd_documents([doc_al])
# get the document from the vector store
v_doc = await v_store.aget_by_document_id(document_id=doc_id)
assert v_doc is not None
assert v_doc.page_content == doc_al.page_content
# Create a GRAPH vector store using the existing table from above
# with setup_mode=gvs_setup_mode and metadata_deny_list=gvs_metadata_deny_list
with graph_vector_store(
embedding=embedding_d2,
table_name=table_name,
setup_mode=gvs_setup_mode,
metadata_deny_list=gvs_metadata_deny_list,
drop=False,
) as gv_store:
# get the document from the GRAPH vector store
gv_doc = await gv_store.aget_by_document_id(document_id=doc_id)
assert gv_doc is not None
assert gv_doc.page_content == doc_al.page_content

View File

@ -0,0 +1,143 @@
from __future__ import annotations
import math
from langchain_community.graph_vectorstores.mmr_helper import MmrHelper
IDS = {
"-1",
"-2",
"-3",
"-4",
"-5",
"+1",
"+2",
"+3",
"+4",
"+5",
}
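# A sketch of the standard maximal-marginal-relevance score the helper
# implements (selection repeatedly pops the highest-scoring candidate):
#
#     mmr(doc) = lambda_mult * sim(query, doc)
#                - (1 - lambda_mult) * max(sim(doc, s) for s in selected)
#
# so lambda_mult=1 ranks by pure query similarity and lambda_mult=0 by
# pure diversity, as the tests below exercise.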
class TestMmrHelper:
def test_mmr_helper_functional(self) -> None:
helper = MmrHelper(k=3, query_embedding=[6, 5], lambda_mult=0.5)
assert len(list(helper.candidate_ids())) == 0
helper.add_candidates({"-1": [3, 5]})
helper.add_candidates({"-2": [3, 5]})
helper.add_candidates({"-3": [2, 6]})
helper.add_candidates({"-4": [1, 6]})
helper.add_candidates({"-5": [0, 6]})
assert len(list(helper.candidate_ids())) == 5
helper.add_candidates({"+1": [5, 3]})
helper.add_candidates({"+2": [5, 3]})
helper.add_candidates({"+3": [6, 2]})
helper.add_candidates({"+4": [6, 1]})
helper.add_candidates({"+5": [6, 0]})
assert len(list(helper.candidate_ids())) == 10
for idx in range(3):
best_id = helper.pop_best()
assert best_id in IDS
assert len(list(helper.candidate_ids())) == 9 - idx
assert best_id not in helper.candidate_ids()
def test_mmr_helper_max_diversity(self) -> None:
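# lambda_mult=0 scores diversity only, so the two most mutually
# distant candidates ("-1" and "-5") should be selected.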
helper = MmrHelper(k=2, query_embedding=[6, 5], lambda_mult=0)
helper.add_candidates({"-1": [3, 5]})
helper.add_candidates({"-2": [3, 5]})
helper.add_candidates({"-3": [2, 6]})
helper.add_candidates({"-4": [1, 6]})
helper.add_candidates({"-5": [0, 6]})
best = {helper.pop_best(), helper.pop_best()}
assert best == {"-1", "-5"}
def test_mmr_helper_max_similarity(self) -> None:
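# lambda_mult=1 scores query similarity only, so the two candidates
# nearest the query ("-1" and "-2", identical vectors) should win.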
helper = MmrHelper(k=2, query_embedding=[6, 5], lambda_mult=1)
helper.add_candidates({"-1": [3, 5]})
helper.add_candidates({"-2": [3, 5]})
helper.add_candidates({"-3": [2, 6]})
helper.add_candidates({"-4": [1, 6]})
helper.add_candidates({"-5": [0, 6]})
best = {helper.pop_best(), helper.pop_best()}
assert best == {"-1", "-2"}
def test_mmr_helper_add_candidate(self) -> None:
helper = MmrHelper(5, [0.0, 1.0])
helper.add_candidates(
{
"a": [0.0, 1.0],
"b": [1.0, 0.0],
}
)
assert helper.best_id == "a"
def test_mmr_helper_pop_best(self) -> None:
helper = MmrHelper(5, [0.0, 1.0])
helper.add_candidates(
{
"a": [0.0, 1.0],
"b": [1.0, 0.0],
}
)
assert helper.pop_best() == "a"
assert helper.pop_best() == "b"
assert helper.pop_best() is None
def angular_embedding(self, angle: float) -> list[float]:
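# Map an angle (in units of pi) to a unit vector on the circle.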
return [math.cos(angle * math.pi), math.sin(angle * math.pi)]
def test_mmr_helper_added_documents(self) -> None:
"""Test end to end construction and MMR search.
The embedding function used here ensures `texts` become
the following vectors on a circle (numbered v0 through v3):
       ______ v2
      /      \
     /        |  v1
 v3 |     .   | query
     \        / v0
      \______/          (N.B. very crude drawing)
With fetch_k==2 and k==2, when the query is at angle 0.0 (i.e. (1, 0)),
one expects that v2 and v0 are returned (in some order)
because v1 is "too close" to v0 (and v0 is closer than v1).
Both v2 and v3 are discovered after v0.
"""
helper = MmrHelper(5, self.angular_embedding(0.0))
# Fetching the 2 nearest neighbors to 0.0
helper.add_candidates(
{
"v0": self.angular_embedding(-0.124),
"v1": self.angular_embedding(+0.127),
}
)
assert helper.pop_best() == "v0"
# After v0 is selected, new nodes are discovered.
# v2 is closer than v3. v1 is "too similar" to "v0" so it's not included.
helper.add_candidates(
{
"v2": self.angular_embedding(+0.25),
"v3": self.angular_embedding(+1.0),
}
)
assert helper.pop_best() == "v2"
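# Expected values, assuming the default lambda_mult of 0.5:
#   sim(query, v0) = cos(0.124 * pi) ~= 0.9251
#   sim(query, v2) = cos(0.250 * pi) ~= 0.7071
#   mmr(v0) = 0.5 * 0.9251 ~= 0.4625 (no redundancy term: nothing selected yet)
#   mmr(v2) = 0.5 * 0.7071 - 0.5 * cos(0.374 * pi) ~= 0.3536 - 0.1928 = 0.1608
#   (redundancy measured against the already-selected v0)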
assert math.isclose(
helper.selected_similarity_scores[0], 0.9251, abs_tol=0.0001
)
assert math.isclose(
helper.selected_similarity_scores[1], 0.7071, abs_tol=0.0001
)
assert math.isclose(helper.selected_mmr_scores[0], 0.4625, abs_tol=0.0001)
assert math.isclose(helper.selected_mmr_scores[1], 0.1608, abs_tol=0.0001)