community: Cassandra Vector Store: modernize implementation (#27253)

**Description:** 

This PR updates `CassandraGraphVectorStore` to be based on
`CassandraVectorStore`, instead of using a custom CQL implementation.
This allows users of a `CassandraVectorStore` to upgrade to a
`GraphVectorStore` without having to change their database schema or
re-embed documents.

This PR also updates the documentation of the `GraphVectorStore` base
class and adds native async implementations of the standard graph
methods, `traversal_search` and `mmr_traversal_search`, in
`CassandraVectorStore`.
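
As a quick illustration, a native async traversal call could look like the
following (a minimal sketch: the session, keyspace, table name, and embedding
model are illustrative placeholders, not part of this PR):

```python
import asyncio

from cassandra.cluster import Cluster
from langchain_community.graph_vectorstores import CassandraGraphVectorStore
from langchain_openai import OpenAIEmbeddings  # any Embeddings implementation works


async def main() -> None:
    session = Cluster().connect()  # assumes a reachable Cassandra cluster
    store = CassandraGraphVectorStore(
        embedding=OpenAIEmbeddings(),
        session=session,
        keyspace="demo",       # placeholder keyspace
        table_name="my_docs",  # placeholder table
    )
    # Runs natively on the event loop instead of wrapping the sync
    # method in run_in_executor.
    async for doc in store.atraversal_search("graph vector stores", k=4, depth=1):
        print(doc.page_content[:80])


asyncio.run(main())
```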

**Issue:** No issue number.

**Dependencies:** https://github.com/langchain-ai/langchain/pull/27078
(already merged)

**Lint and test:**
- Lint and tests all pass, including the existing
`CassandraGraphVectorStore` tests.
- Also added numerous additional tests, based on the tests in
`langchain-astradb`, which cover many more scenarios than the existing
tests for `Cassandra` and `CassandraGraphVectorStore`.

**BREAKING CHANGE**

Note that this is a breaking change for existing users of
`CassandraGraphVectorStore`: they will need to wipe their database table
and start over.

However:
- The interfaces have not changed; only the underlying storage
mechanism has.
- Anyone using `langchain_community.vectorstores.Cassandra` can instead
use `langchain_community.graph_vectorstores.CassandraGraphVectorStore`
and gain graph capabilities without having to re-embed their
existing documents. This is the primary goal of this PR.
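
For example, the upgrade is a drop-in swap (a minimal sketch: `embeddings`,
`session`, and the keyspace/table names are assumed to already exist):

```python
from langchain_community.vectorstores import Cassandra
from langchain_community.graph_vectorstores import CassandraGraphVectorStore

# Existing vector store, with documents already embedded and stored.
vector_store = Cassandra(
    embedding=embeddings, session=session, keyspace="demo", table_name="my_docs"
)

# Reopen the SAME table as a graph vector store: no schema change, no re-embedding.
graph_store = CassandraGraphVectorStore(
    embedding=embeddings, session=session, keyspace="demo", table_name="my_docs"
)
docs = list(graph_store.traversal_search("my query", k=4, depth=1))
```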

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
Author: Eric Pinzur
Committed: 2024-10-22 20:11:11 +02:00 (via GitHub)
Commit: f636c83321 (parent: 0640cbf2f1)
9 changed files with 4070 additions and 679 deletions


@@ -144,6 +144,7 @@ from langchain_community.graph_vectorstores.cassandra import CassandraGraphVecto
from langchain_community.graph_vectorstores.links import (
Link,
)
+ from langchain_community.graph_vectorstores.mmr_helper import MmrHelper
__all__ = [
"GraphVectorStore",
@@ -151,4 +152,5 @@ __all__ = [
"Node",
"Link",
"CassandraGraphVectorStore",
"MmrHelper",
]


@@ -1,11 +1,13 @@
from __future__ import annotations
+ import logging
from abc import abstractmethod
from collections.abc import AsyncIterable, Collection, Iterable, Iterator
from typing import (
Any,
ClassVar,
Optional,
+ Sequence,
)
from langchain_core._api import beta
@@ -21,6 +23,8 @@ from pydantic import Field
from langchain_community.graph_vectorstores.links import METADATA_LINKS_KEY, Link
+ logger = logging.getLogger(__name__)
def _has_next(iterator: Iterator) -> bool:
"""Checks if the iterator has more elements.
@@ -158,6 +162,7 @@ class GraphVectorStore(VectorStore):
Args:
nodes: the nodes to add.
+ **kwargs: Additional keyword arguments.
"""
async def aadd_nodes(
@@ -169,6 +174,7 @@ class GraphVectorStore(VectorStore):
Args:
nodes: the nodes to add.
+ **kwargs: Additional keyword arguments.
"""
iterator = iter(await run_in_executor(None, self.add_nodes, nodes, **kwargs))
done = object()
@@ -186,7 +192,7 @@ class GraphVectorStore(VectorStore):
ids: Optional[Iterable[str]] = None,
**kwargs: Any,
) -> list[str]:
"""Run more texts through the embeddings and add to the vectorstore.
"""Run more texts through the embeddings and add to the vector store.
The Links present in the metadata field `links` will be extracted to create
the `Node` links.
@@ -214,15 +220,15 @@ class GraphVectorStore(VectorStore):
)
Args:
- texts: Iterable of strings to add to the vectorstore.
+ texts: Iterable of strings to add to the vector store.
metadatas: Optional list of metadatas associated with the texts.
The metadata key `links` shall be an iterable of
:py:class:`~langchain_community.graph_vectorstores.links.Link`.
ids: Optional list of IDs associated with the texts.
- **kwargs: vectorstore specific parameters.
+ **kwargs: vector store specific parameters.
Returns:
- List of ids from adding the texts into the vectorstore.
+ List of ids from adding the texts into the vector store.
"""
nodes = _texts_to_nodes(texts, metadatas, ids)
return list(self.add_nodes(nodes, **kwargs))
@@ -235,7 +241,7 @@ class GraphVectorStore(VectorStore):
ids: Optional[Iterable[str]] = None,
**kwargs: Any,
) -> list[str]:
"""Run more texts through the embeddings and add to the vectorstore.
"""Run more texts through the embeddings and add to the vector store.
The Links present in the metadata field `links` will be extracted to create
the `Node` links.
@@ -263,15 +269,15 @@ class GraphVectorStore(VectorStore):
)
Args:
- texts: Iterable of strings to add to the vectorstore.
+ texts: Iterable of strings to add to the vector store.
metadatas: Optional list of metadatas associated with the texts.
The metadata key `links` shall be an iterable of
:py:class:`~langchain_community.graph_vectorstores.links.Link`.
ids: Optional list of IDs associated with the texts.
- **kwargs: vectorstore specific parameters.
+ **kwargs: vector store specific parameters.
Returns:
- List of ids from adding the texts into the vectorstore.
+ List of ids from adding the texts into the vector store.
"""
nodes = _texts_to_nodes(texts, metadatas, ids)
return [_id async for _id in self.aadd_nodes(nodes, **kwargs)]
@@ -281,7 +287,7 @@ class GraphVectorStore(VectorStore):
documents: Iterable[Document],
**kwargs: Any,
) -> list[str]:
"""Run more documents through the embeddings and add to the vectorstore.
"""Run more documents through the embeddings and add to the vector store.
The Links present in the document metadata field `links` will be extracted to
create the `Node` links.
@@ -316,7 +322,7 @@ class GraphVectorStore(VectorStore):
)
Args:
- documents: Documents to add to the vectorstore.
+ documents: Documents to add to the vector store.
The document's metadata key `links` shall be an iterable of
:py:class:`~langchain_community.graph_vectorstores.links.Link`.
@@ -331,7 +337,7 @@ class GraphVectorStore(VectorStore):
documents: Iterable[Document],
**kwargs: Any,
) -> list[str]:
"""Run more documents through the embeddings and add to the vectorstore.
"""Run more documents through the embeddings and add to the vector store.
The Links present in the document metadata field `links` will be extracted to
create the `Node` links.
@@ -366,7 +372,7 @@ class GraphVectorStore(VectorStore):
)
Args:
- documents: Documents to add to the vectorstore.
+ documents: Documents to add to the vector store.
The document's metadata key `links` shall be an iterable of
:py:class:`~langchain_community.graph_vectorstores.links.Link`.
@@ -383,6 +389,7 @@ class GraphVectorStore(VectorStore):
*,
k: int = 4,
depth: int = 1,
+ filter: dict[str, Any] | None = None,  # noqa: A002
**kwargs: Any,
) -> Iterable[Document]:
"""Retrieve documents from traversing this graph store.
@@ -396,8 +403,10 @@ class GraphVectorStore(VectorStore):
k: The number of Documents to return from the initial search.
Defaults to 4. Applies to each of the query strings.
depth: The maximum depth of edges to traverse. Defaults to 1.
+ filter: Optional metadata to filter the results.
+ **kwargs: Additional keyword arguments.
Returns:
- Retrieved documents.
+ Collection of retrieved documents.
"""
async def atraversal_search(
@@ -406,6 +415,7 @@ class GraphVectorStore(VectorStore):
*,
k: int = 4,
depth: int = 1,
+ filter: dict[str, Any] | None = None,  # noqa: A002
**kwargs: Any,
) -> AsyncIterable[Document]:
"""Retrieve documents from traversing this graph store.
@@ -419,12 +429,20 @@ class GraphVectorStore(VectorStore):
k: The number of Documents to return from the initial search.
Defaults to 4. Applies to each of the query strings.
depth: The maximum depth of edges to traverse. Defaults to 1.
+ filter: Optional metadata to filter the results.
+ **kwargs: Additional keyword arguments.
Returns:
- Retrieved documents.
+ Collection of retrieved documents.
"""
iterator = iter(
await run_in_executor(
- None, self.traversal_search, query, k=k, depth=depth, **kwargs
+ None,
+ self.traversal_search,
+ query,
+ k=k,
+ depth=depth,
+ filter=filter,
+ **kwargs,
)
)
done = object()
@@ -439,12 +457,14 @@ class GraphVectorStore(VectorStore):
self,
query: str,
*,
+ initial_roots: Sequence[str] = (),
k: int = 4,
depth: int = 2,
fetch_k: int = 100,
adjacent_k: int = 10,
lambda_mult: float = 0.5,
score_threshold: float = float("-inf"),
+ filter: dict[str, Any] | None = None,  # noqa: A002
**kwargs: Any,
) -> Iterable[Document]:
"""Retrieve documents from this graph store using MMR-traversal.
@@ -459,6 +479,10 @@ class GraphVectorStore(VectorStore):
Args:
query: The query string to search for.
+ initial_roots: Optional list of document IDs to use for initializing search.
+ The top `adjacent_k` nodes adjacent to each initial root will be
+ included in the set of initial candidates. To fetch only in the
+ neighborhood of these nodes, set `fetch_k = 0`.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch via similarity.
Defaults to 100.
@@ -471,18 +495,22 @@ class GraphVectorStore(VectorStore):
diversity and 1 to minimum diversity. Defaults to 0.5.
score_threshold: Only documents with a score greater than or equal
this threshold will be chosen. Defaults to negative infinity.
+ filter: Optional metadata to filter the results.
+ **kwargs: Additional keyword arguments.
"""
async def ammr_traversal_search(
self,
query: str,
*,
+ initial_roots: Sequence[str] = (),
k: int = 4,
depth: int = 2,
fetch_k: int = 100,
adjacent_k: int = 10,
lambda_mult: float = 0.5,
score_threshold: float = float("-inf"),
+ filter: dict[str, Any] | None = None,  # noqa: A002
**kwargs: Any,
) -> AsyncIterable[Document]:
"""Retrieve documents from this graph store using MMR-traversal.
@@ -497,6 +525,10 @@ class GraphVectorStore(VectorStore):
Args:
query: The query string to search for.
+ initial_roots: Optional list of document IDs to use for initializing search.
+ The top `adjacent_k` nodes adjacent to each initial root will be
+ included in the set of initial candidates. To fetch only in the
+ neighborhood of these nodes, set `fetch_k = 0`.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch via similarity.
Defaults to 100.
@@ -509,18 +541,22 @@ class GraphVectorStore(VectorStore):
diversity and 1 to minimum diversity. Defaults to 0.5.
score_threshold: Only documents with a score greater than or equal
this threshold will be chosen. Defaults to negative infinity.
+ filter: Optional metadata to filter the results.
+ **kwargs: Additional keyword arguments.
"""
iterator = iter(
await run_in_executor(
None,
self.mmr_traversal_search,
query,
+ initial_roots=initial_roots,
k=k,
fetch_k=fetch_k,
adjacent_k=adjacent_k,
depth=depth,
lambda_mult=lambda_mult,
score_threshold=score_threshold,
+ filter=filter,
**kwargs,
)
)
@@ -544,6 +580,11 @@ class GraphVectorStore(VectorStore):
lambda_mult: float = 0.5,
**kwargs: Any,
) -> list[Document]:
+ if kwargs.get("depth", 0) > 0:
+ logger.warning(
+ "'mmr' search started with depth > 0. "
+ "Maybe you meant to do a 'mmr_traversal' search?"
+ )
return list(
self.mmr_traversal_search(
query, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, depth=0
@@ -573,7 +614,7 @@ class GraphVectorStore(VectorStore):
raise ValueError(
f"search_type of {search_type} not allowed. Expected "
"search_type to be 'similarity', 'similarity_score_threshold', "
"'mmr' or 'traversal'."
"'mmr', 'traversal', or 'mmr_traversal'."
)
async def asearch(
@@ -590,11 +631,13 @@ class GraphVectorStore(VectorStore):
return await self.amax_marginal_relevance_search(query, **kwargs)
elif search_type == "traversal":
return [doc async for doc in self.atraversal_search(query, **kwargs)]
+ elif search_type == "mmr_traversal":
+ return [doc async for doc in self.ammr_traversal_search(query, **kwargs)]
else:
raise ValueError(
f"search_type of {search_type} not allowed. Expected "
"search_type to be 'similarity', 'similarity_score_threshold', "
"'mmr' or 'traversal'."
"'mmr', 'traversal', or 'mmr_traversal'."
)
def as_retriever(self, **kwargs: Any) -> GraphVectorStoreRetriever:
@@ -606,13 +649,14 @@ class GraphVectorStore(VectorStore):
- search_type (Optional[str]): Defines the type of search that
the Retriever should perform.
- Can be ``traversal`` (default), ``similarity``, ``mmr``, or
- ``similarity_score_threshold``.
+ Can be ``traversal`` (default), ``similarity``, ``mmr``,
+ ``mmr_traversal``, or ``similarity_score_threshold``.
- search_kwargs (Optional[Dict]): Keyword arguments to pass to the
search function. Can include things like:
- k(int): Amount of documents to return (Default: 4).
- depth(int): The maximum depth of edges to traverse (Default: 1).
+ Only applies to search_type: ``traversal`` and ``mmr_traversal``.
- score_threshold(float): Minimum relevance threshold
for similarity_score_threshold.
- fetch_k(int): Amount of documents to pass to MMR algorithm
@@ -629,21 +673,21 @@ class GraphVectorStore(VectorStore):
# Retrieve documents traversing edges
docsearch.as_retriever(
search_type="traversal",
- search_kwargs={'k': 6, 'depth': 3}
+ search_kwargs={'k': 6, 'depth': 2}
)
- # Retrieve more documents with higher diversity
+ # Retrieve documents with higher diversity
# Useful if your dataset has many similar documents
docsearch.as_retriever(
search_type="mmr",
search_kwargs={'k': 6, 'lambda_mult': 0.25}
search_type="mmr_traversal",
search_kwargs={'k': 6, 'lambda_mult': 0.25, 'depth': 2}
)
# Fetch more documents for the MMR algorithm to consider
# But only return the top 5
docsearch.as_retriever(
search_type="mmr",
search_kwargs={'k': 5, 'fetch_k': 50}
search_type="mmr_traversal",
search_kwargs={'k': 5, 'fetch_k': 50, 'depth': 2}
)
# Only retrieve documents that have a relevance score
@@ -657,7 +701,7 @@ class GraphVectorStore(VectorStore):
docsearch.as_retriever(search_kwargs={'k': 1})
"""
- return GraphVectorStoreRetriever(vectorstore=self, **kwargs)
+ return GraphVectorStoreRetriever(vector_store=self, **kwargs)
@beta(message="Added in version 0.3.1 of langchain_community. API subject to change.")
@@ -744,7 +788,7 @@ class GraphVectorStoreRetriever(VectorStoreRetriever):
Passing search parameters
-------------------------
- We can pass parameters to the underlying graph vectorstore's search methods using
+ We can pass parameters to the underlying graph vector store's search methods using
``search_kwargs``.
Specifying graph traversal depth
@@ -793,7 +837,7 @@ class GraphVectorStoreRetriever(VectorStoreRetriever):
retriever = graph_vectorstore.as_retriever(search_kwargs={"score_threshold": 0.5})
""" # noqa: E501
- vectorstore: GraphVectorStore
+ vector_store: GraphVectorStore
"""GraphVectorStore to use for retrieval."""
search_type: str = "traversal"
"""Type of search to perform. Defaults to "traversal"."""
@@ -809,10 +853,10 @@ class GraphVectorStoreRetriever(VectorStoreRetriever):
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> list[Document]:
if self.search_type == "traversal":
- return list(self.vectorstore.traversal_search(query, **self.search_kwargs))
+ return list(self.vector_store.traversal_search(query, **self.search_kwargs))
elif self.search_type == "mmr_traversal":
return list(
- self.vectorstore.mmr_traversal_search(query, **self.search_kwargs)
+ self.vector_store.mmr_traversal_search(query, **self.search_kwargs)
)
else:
return super()._get_relevant_documents(query, run_manager=run_manager)
@@ -823,14 +867,14 @@ class GraphVectorStoreRetriever(VectorStoreRetriever):
if self.search_type == "traversal":
return [
doc
- async for doc in self.vectorstore.atraversal_search(
+ async for doc in self.vector_store.atraversal_search(
query, **self.search_kwargs
)
]
elif self.search_type == "mmr_traversal":
return [
doc
- async for doc in self.vectorstore.ammr_traversal_search(
+ async for doc in self.vector_store.ammr_traversal_search(
query, **self.search_kwargs
)
]
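
Taken together, a retriever over the new search type and filter support might be
configured like this (a sketch; `graph_store` is assumed to be an existing
`CassandraGraphVectorStore`, and the metadata key is hypothetical):

```python
retriever = graph_store.as_retriever(
    search_type="mmr_traversal",
    search_kwargs={
        "k": 4,
        "depth": 2,
        "fetch_k": 50,
        "lambda_mult": 0.25,
        "filter": {"category": "reference"},  # hypothetical metadata filter
    },
)
docs = retriever.invoke("how does graph traversal work?")
```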


@@ -0,0 +1,272 @@
"""Tools for the Graph Traversal Maximal Marginal Relevance (MMR) reranking."""
from __future__ import annotations
import dataclasses
from typing import TYPE_CHECKING, Iterable
import numpy as np
from langchain_community.utils.math import cosine_similarity
if TYPE_CHECKING:
from numpy.typing import NDArray
def _emb_to_ndarray(embedding: list[float]) -> NDArray[np.float32]:
emb_array = np.array(embedding, dtype=np.float32)
if emb_array.ndim == 1:
emb_array = np.expand_dims(emb_array, axis=0)
return emb_array
NEG_INF = float("-inf")
@dataclasses.dataclass
class _Candidate:
id: str
similarity: float
weighted_similarity: float
weighted_redundancy: float
score: float = dataclasses.field(init=False)
def __post_init__(self) -> None:
self.score = self.weighted_similarity - self.weighted_redundancy
def update_redundancy(self, new_weighted_redundancy: float) -> None:
if new_weighted_redundancy > self.weighted_redundancy:
self.weighted_redundancy = new_weighted_redundancy
self.score = self.weighted_similarity - self.weighted_redundancy
class MmrHelper:
"""Helper for executing an MMR traversal query.
Args:
query_embedding: The embedding of the query to use for scoring.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding to maximum
diversity and 1 to minimum diversity. Defaults to 0.5.
score_threshold: Only documents with a score greater than or equal
this threshold will be chosen. Defaults to -infinity.
"""
dimensions: int
"""Dimensions of the embedding."""
query_embedding: NDArray[np.float32]
"""Embedding of the query as a (1,dim) ndarray."""
lambda_mult: float
"""Number between 0 and 1.
Determines the degree of diversity among the results with 0 corresponding to
maximum diversity and 1 to minimum diversity."""
lambda_mult_complement: float
"""1 - lambda_mult."""
score_threshold: float
"""Only documents with a score greater than or equal to this will be chosen."""
selected_ids: list[str]
"""List of selected IDs (in selection order)."""
selected_mmr_scores: list[float]
"""List of MMR score at the time each document is selected."""
selected_similarity_scores: list[float]
"""List of similarity score for each selected document."""
selected_embeddings: NDArray[np.float32]
"""(N, dim) ndarray with a row for each selected node."""
candidate_id_to_index: dict[str, int]
"""Dictionary of candidate IDs to indices in candidates and candidate_embeddings."""
candidates: list[_Candidate]
"""List containing information about candidates.
Same order as rows in `candidate_embeddings`.
"""
candidate_embeddings: NDArray[np.float32]
"""(N, dim) ndarray with a row for each candidate."""
best_score: float
best_id: str | None
def __init__(
self,
k: int,
query_embedding: list[float],
lambda_mult: float = 0.5,
score_threshold: float = NEG_INF,
) -> None:
"""Create a new Traversal MMR helper."""
self.query_embedding = _emb_to_ndarray(query_embedding)
self.dimensions = self.query_embedding.shape[1]
self.lambda_mult = lambda_mult
self.lambda_mult_complement = 1 - lambda_mult
self.score_threshold = score_threshold
self.selected_ids = []
self.selected_similarity_scores = []
self.selected_mmr_scores = []
# List of selected embeddings (in selection order).
self.selected_embeddings = np.ndarray((k, self.dimensions), dtype=np.float32)
self.candidate_id_to_index = {}
# List of the candidates.
self.candidates = []
# numpy n-dimensional array of the candidate embeddings.
self.candidate_embeddings = np.ndarray((0, self.dimensions), dtype=np.float32)
self.best_score = NEG_INF
self.best_id = None
def candidate_ids(self) -> Iterable[str]:
"""Return the IDs of the candidates."""
return self.candidate_id_to_index.keys()
def _already_selected_embeddings(self) -> NDArray[np.float32]:
"""Return the selected embeddings sliced to the already assigned values."""
selected = len(self.selected_ids)
return np.vsplit(self.selected_embeddings, [selected])[0]
def _pop_candidate(self, candidate_id: str) -> tuple[float, NDArray[np.float32]]:
"""Pop the candidate with the given ID.
Returns:
The similarity score and embedding of the candidate.
"""
# Get the embedding for the id.
index = self.candidate_id_to_index.pop(candidate_id)
if self.candidates[index].id != candidate_id:
msg = (
"ID in self.candidate_id_to_index doesn't match the ID of the "
"corresponding index in self.candidates"
)
raise ValueError(msg)
embedding: NDArray[np.float32] = self.candidate_embeddings[index].copy()
# Swap that index with the last index in the candidates and
# candidate_embeddings.
last_index = self.candidate_embeddings.shape[0] - 1
similarity = 0.0
if index == last_index:
# Already the last item. We don't need to swap.
similarity = self.candidates.pop().similarity
else:
self.candidate_embeddings[index] = self.candidate_embeddings[last_index]
similarity = self.candidates[index].similarity
old_last = self.candidates.pop()
self.candidates[index] = old_last
self.candidate_id_to_index[old_last.id] = index
self.candidate_embeddings = np.vsplit(self.candidate_embeddings, [last_index])[
0
]
return similarity, embedding
def pop_best(self) -> str | None:
"""Select and pop the best item being considered.
Updates the consideration set based on it.
Returns:
The ID of the best item, or None if no candidate scores above the threshold.
"""
if self.best_id is None or self.best_score < self.score_threshold:
return None
# Get the selection and remove from candidates.
selected_id = self.best_id
selected_similarity, selected_embedding = self._pop_candidate(selected_id)
# Add the ID and embedding to the selected information.
selection_index = len(self.selected_ids)
self.selected_ids.append(selected_id)
self.selected_mmr_scores.append(self.best_score)
self.selected_similarity_scores.append(selected_similarity)
self.selected_embeddings[selection_index] = selected_embedding
# Reset the best score / best ID.
self.best_score = NEG_INF
self.best_id = None
# Update the candidates redundancy, tracking the best node.
if self.candidate_embeddings.shape[0] > 0:
similarity = cosine_similarity(
self.candidate_embeddings, np.expand_dims(selected_embedding, axis=0)
)
for index, candidate in enumerate(self.candidates):
candidate.update_redundancy(similarity[index][0])
if candidate.score > self.best_score:
self.best_score = candidate.score
self.best_id = candidate.id
return selected_id
def add_candidates(self, candidates: dict[str, list[float]]) -> None:
"""Add candidates to the consideration set."""
# Determine the keys to actually include.
# These are the candidates that aren't already selected
# or under consideration.
include_ids_set = set(candidates.keys())
include_ids_set.difference_update(self.selected_ids)
include_ids_set.difference_update(self.candidate_id_to_index.keys())
include_ids = list(include_ids_set)
# Now, build up a matrix of the remaining candidate embeddings,
# and add them to the candidate_id_to_index map.
new_embeddings: NDArray[np.float32] = np.ndarray(
(
len(include_ids),
self.dimensions,
)
)
offset = self.candidate_embeddings.shape[0]
for index, candidate_id in enumerate(include_ids):
if candidate_id in include_ids:
self.candidate_id_to_index[candidate_id] = offset + index
embedding = candidates[candidate_id]
new_embeddings[index] = embedding
# Compute the similarity to the query.
similarity = cosine_similarity(new_embeddings, self.query_embedding)
# Compute the distance metrics of all of pairs in the selected set with
# the new candidates.
redundancy = cosine_similarity(
new_embeddings, self._already_selected_embeddings()
)
for index, candidate_id in enumerate(include_ids):
max_redundancy = 0.0
if redundancy.shape[0] > 0:
max_redundancy = redundancy[index].max()
candidate = _Candidate(
id=candidate_id,
similarity=similarity[index][0],
weighted_similarity=self.lambda_mult * similarity[index][0],
weighted_redundancy=self.lambda_mult_complement * max_redundancy,
)
self.candidates.append(candidate)
if candidate.score >= self.best_score:
self.best_score = candidate.score
self.best_id = candidate.id
# Add the new embeddings to the candidate set.
self.candidate_embeddings = np.vstack(
(
self.candidate_embeddings,
new_embeddings,
)
)
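
To make the select-and-expand loop concrete, here is a toy walkthrough of
`MmrHelper` (a sketch with 2-dimensional embeddings; real callers feed in
embeddings fetched from the store and expand the candidate set from graph
edges after each selection):

```python
from langchain_community.graph_vectorstores.mmr_helper import MmrHelper

# k=2: we want two results. lambda_mult=0.25 weights diversity heavily.
helper = MmrHelper(k=2, query_embedding=[1.0, 0.0], lambda_mult=0.25)

# Seed the consideration set with candidate IDs and their embeddings.
helper.add_candidates(
    {
        "a": [0.9, 0.1],    # close to the query
        "b": [0.89, 0.11],  # nearly redundant with "a"
        "c": [0.1, 0.9],    # far from the query, but diverse
    }
)

selected = []
for _ in range(2):
    best = helper.pop_best()
    if best is None:  # no remaining candidate clears score_threshold
        break
    selected.append(best)
    # A real traversal would now look up nodes adjacent to `best` and feed
    # them back in via helper.add_candidates(...).

print(selected)  # ["a", "c"]: once "a" is picked, the near-duplicate "b"
                 # is penalized for redundancy and the diverse "c" wins.
```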