Mirror of https://github.com/hwchase17/langchain.git
Synced 2025-06-19 05:13:46 +00:00
community: Cassandra Vector Store: modernize implementation (#27253)
**Description:** This PR updates `CassandraGraphVectorStore` to be based on `CassandraVectorStore` instead of a custom CQL implementation. This allows users of a `CassandraVectorStore` to upgrade to a `GraphVectorStore` without having to change their database schema or re-embed documents. This PR also updates the documentation of the `GraphVectorStore` base class and adds native async implementations of the standard graph methods `traversal_search` and `mmr_traversal_search` in `CassandraVectorStore`.

**Issue:** No issue number.

**Dependencies:** https://github.com/langchain-ai/langchain/pull/27078 (already merged)

**Lint and test:**
- Lint and tests all pass, including the existing `CassandraGraphVectorStore` tests.
- Also added numerous additional tests based on the tests in `langchain-astradb`, which cover many more scenarios than the existing tests for `Cassandra` and `CassandraGraphVectorStore`.

**BREAKING CHANGE:** This is a breaking change for existing users of `CassandraGraphVectorStore`: they will need to wipe their database table and restart. However:
- The interfaces have not changed, only the underlying storage mechanism.
- Anyone using `langchain_community.vectorstores.Cassandra` can instead use `langchain_community.graph_vectorstores.CassandraGraphVectorStore` and gain graph capabilities without having to re-embed their existing documents. This is the primary goal of this PR.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
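For readers weighing the migration, here is a minimal sketch of the upgrade path described above. Assumptions, not part of this commit: a plain local Cassandra cluster, `OpenAIEmbeddings` as the embedding model, and `my_keyspace`/`my_table` standing in for the keyspace and table the existing `Cassandra` vector store already uses; verify the constructor arguments against the released API.

```python
from cassandra.cluster import Cluster
from langchain_community.graph_vectorstores import CassandraGraphVectorStore
from langchain_openai import OpenAIEmbeddings

session = Cluster().connect()

# Point the graph store at the table a plain Cassandra vector store already
# populated: same schema, same embeddings, no re-embedding required.
graph_store = CassandraGraphVectorStore(
    embedding=OpenAIEmbeddings(),
    session=session,
    keyspace="my_keyspace",  # placeholder
    table_name="my_table",   # placeholder: the existing vector store's table
)

# Graph-aware retrieval over the pre-existing rows.
docs = list(graph_store.traversal_search("What is Cassandra?", k=4, depth=1))
```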
This commit is contained in:
parent 0640cbf2f1, commit f636c83321
Graph vector stores package exports (`MmrHelper` added):

```diff
@@ -144,6 +144,7 @@ from langchain_community.graph_vectorstores.cassandra import CassandraGraphVecto
 from langchain_community.graph_vectorstores.links import (
     Link,
 )
+from langchain_community.graph_vectorstores.mmr_helper import MmrHelper

 __all__ = [
     "GraphVectorStore",
@@ -151,4 +152,5 @@ __all__ = [
     "Node",
     "Link",
     "CassandraGraphVectorStore",
+    "MmrHelper",
 ]
```
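The net effect of the hunk above is that `MmrHelper` joins the package's public exports. A quick sanity-check sketch, assuming a build that contains this commit:

```python
from langchain_community.graph_vectorstores import MmrHelper

# k and query_embedding mirror the constructor shown in the new mmr_helper
# module below; a 2-dimensional embedding keeps the example small.
helper = MmrHelper(k=4, query_embedding=[0.0, 1.0], lambda_mult=0.5)
print(list(helper.candidate_ids()))  # [] -- no candidates added yet
```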
`GraphVectorStore` base class and `GraphVectorStoreRetriever` (`graph_vectorstores/base.py`):

```diff
@@ -1,11 +1,13 @@
 from __future__ import annotations

+import logging
 from abc import abstractmethod
 from collections.abc import AsyncIterable, Collection, Iterable, Iterator
 from typing import (
     Any,
     ClassVar,
     Optional,
     Sequence,
 )

 from langchain_core._api import beta
@@ -21,6 +23,8 @@ from pydantic import Field

 from langchain_community.graph_vectorstores.links import METADATA_LINKS_KEY, Link

+logger = logging.getLogger(__name__)
+

 def _has_next(iterator: Iterator) -> bool:
     """Checks if the iterator has more elements.
@@ -158,6 +162,7 @@ class GraphVectorStore(VectorStore):

         Args:
             nodes: the nodes to add.
+            **kwargs: Additional keyword arguments.
         """

     async def aadd_nodes(
@@ -169,6 +174,7 @@ class GraphVectorStore(VectorStore):

         Args:
             nodes: the nodes to add.
+            **kwargs: Additional keyword arguments.
         """
         iterator = iter(await run_in_executor(None, self.add_nodes, nodes, **kwargs))
         done = object()
@@ -186,7 +192,7 @@ class GraphVectorStore(VectorStore):
         ids: Optional[Iterable[str]] = None,
         **kwargs: Any,
     ) -> list[str]:
-        """Run more texts through the embeddings and add to the vectorstore.
+        """Run more texts through the embeddings and add to the vector store.

         The Links present in the metadata field `links` will be extracted to create
         the `Node` links.
@@ -214,15 +220,15 @@ class GraphVectorStore(VectorStore):
             )

         Args:
-            texts: Iterable of strings to add to the vectorstore.
+            texts: Iterable of strings to add to the vector store.
             metadatas: Optional list of metadatas associated with the texts.
                 The metadata key `links` shall be an iterable of
                 :py:class:`~langchain_community.graph_vectorstores.links.Link`.
             ids: Optional list of IDs associated with the texts.
-            **kwargs: vectorstore specific parameters.
+            **kwargs: vector store specific parameters.

         Returns:
-            List of ids from adding the texts into the vectorstore.
+            List of ids from adding the texts into the vector store.
         """
         nodes = _texts_to_nodes(texts, metadatas, ids)
         return list(self.add_nodes(nodes, **kwargs))
@@ -235,7 +241,7 @@ class GraphVectorStore(VectorStore):
         ids: Optional[Iterable[str]] = None,
         **kwargs: Any,
     ) -> list[str]:
-        """Run more texts through the embeddings and add to the vectorstore.
+        """Run more texts through the embeddings and add to the vector store.

         The Links present in the metadata field `links` will be extracted to create
         the `Node` links.
@@ -263,15 +269,15 @@ class GraphVectorStore(VectorStore):
             )

         Args:
-            texts: Iterable of strings to add to the vectorstore.
+            texts: Iterable of strings to add to the vector store.
             metadatas: Optional list of metadatas associated with the texts.
                 The metadata key `links` shall be an iterable of
                 :py:class:`~langchain_community.graph_vectorstores.links.Link`.
             ids: Optional list of IDs associated with the texts.
-            **kwargs: vectorstore specific parameters.
+            **kwargs: vector store specific parameters.

         Returns:
-            List of ids from adding the texts into the vectorstore.
+            List of ids from adding the texts into the vector store.
         """
         nodes = _texts_to_nodes(texts, metadatas, ids)
         return [_id async for _id in self.aadd_nodes(nodes, **kwargs)]
@@ -281,7 +287,7 @@ class GraphVectorStore(VectorStore):
         documents: Iterable[Document],
         **kwargs: Any,
     ) -> list[str]:
-        """Run more documents through the embeddings and add to the vectorstore.
+        """Run more documents through the embeddings and add to the vector store.

         The Links present in the document metadata field `links` will be extracted to
         create the `Node` links.
@@ -316,7 +322,7 @@ class GraphVectorStore(VectorStore):
             )

         Args:
-            documents: Documents to add to the vectorstore.
+            documents: Documents to add to the vector store.
                 The document's metadata key `links` shall be an iterable of
                 :py:class:`~langchain_community.graph_vectorstores.links.Link`.

@@ -331,7 +337,7 @@ class GraphVectorStore(VectorStore):
         documents: Iterable[Document],
         **kwargs: Any,
     ) -> list[str]:
-        """Run more documents through the embeddings and add to the vectorstore.
+        """Run more documents through the embeddings and add to the vector store.

         The Links present in the document metadata field `links` will be extracted to
         create the `Node` links.
@@ -366,7 +372,7 @@ class GraphVectorStore(VectorStore):
             )

         Args:
-            documents: Documents to add to the vectorstore.
+            documents: Documents to add to the vector store.
                 The document's metadata key `links` shall be an iterable of
                 :py:class:`~langchain_community.graph_vectorstores.links.Link`.

@@ -383,6 +389,7 @@ class GraphVectorStore(VectorStore):
         *,
         k: int = 4,
         depth: int = 1,
+        filter: dict[str, Any] | None = None,  # noqa: A002
         **kwargs: Any,
     ) -> Iterable[Document]:
         """Retrieve documents from traversing this graph store.
@@ -396,8 +403,10 @@ class GraphVectorStore(VectorStore):
             k: The number of Documents to return from the initial search.
                 Defaults to 4. Applies to each of the query strings.
             depth: The maximum depth of edges to traverse. Defaults to 1.
+            filter: Optional metadata to filter the results.
+            **kwargs: Additional keyword arguments.

         Returns:
-            Retrieved documents.
+            Collection of retrieved documents.
         """

     async def atraversal_search(
@@ -406,6 +415,7 @@ class GraphVectorStore(VectorStore):
         *,
         k: int = 4,
         depth: int = 1,
+        filter: dict[str, Any] | None = None,  # noqa: A002
         **kwargs: Any,
     ) -> AsyncIterable[Document]:
         """Retrieve documents from traversing this graph store.
@@ -419,12 +429,20 @@ class GraphVectorStore(VectorStore):
             k: The number of Documents to return from the initial search.
                 Defaults to 4. Applies to each of the query strings.
             depth: The maximum depth of edges to traverse. Defaults to 1.
+            filter: Optional metadata to filter the results.
+            **kwargs: Additional keyword arguments.

         Returns:
-            Retrieved documents.
+            Collection of retrieved documents.
         """
         iterator = iter(
             await run_in_executor(
-                None, self.traversal_search, query, k=k, depth=depth, **kwargs
+                None,
+                self.traversal_search,
+                query,
+                k=k,
+                depth=depth,
+                filter=filter,
+                **kwargs,
             )
         )
         done = object()
@@ -439,12 +457,14 @@ class GraphVectorStore(VectorStore):
         self,
         query: str,
         *,
         initial_roots: Sequence[str] = (),
         k: int = 4,
         depth: int = 2,
         fetch_k: int = 100,
         adjacent_k: int = 10,
         lambda_mult: float = 0.5,
         score_threshold: float = float("-inf"),
+        filter: dict[str, Any] | None = None,  # noqa: A002
         **kwargs: Any,
     ) -> Iterable[Document]:
         """Retrieve documents from this graph store using MMR-traversal.
@@ -459,6 +479,10 @@ class GraphVectorStore(VectorStore):

         Args:
             query: The query string to search for.
+            initial_roots: Optional list of document IDs to use for initializing search.
+                The top `adjacent_k` nodes adjacent to each initial root will be
+                included in the set of initial candidates. To fetch only in the
+                neighborhood of these nodes, set `fetch_k = 0`.
             k: Number of Documents to return. Defaults to 4.
             fetch_k: Number of Documents to fetch via similarity.
                 Defaults to 100.
@@ -471,18 +495,22 @@ class GraphVectorStore(VectorStore):
                 diversity and 1 to minimum diversity. Defaults to 0.5.
             score_threshold: Only documents with a score greater than or equal
                 this threshold will be chosen. Defaults to negative infinity.
+            filter: Optional metadata to filter the results.
+            **kwargs: Additional keyword arguments.
         """

     async def ammr_traversal_search(
         self,
         query: str,
         *,
         initial_roots: Sequence[str] = (),
         k: int = 4,
         depth: int = 2,
         fetch_k: int = 100,
         adjacent_k: int = 10,
         lambda_mult: float = 0.5,
         score_threshold: float = float("-inf"),
+        filter: dict[str, Any] | None = None,  # noqa: A002
         **kwargs: Any,
     ) -> AsyncIterable[Document]:
         """Retrieve documents from this graph store using MMR-traversal.
@@ -497,6 +525,10 @@ class GraphVectorStore(VectorStore):

         Args:
             query: The query string to search for.
+            initial_roots: Optional list of document IDs to use for initializing search.
+                The top `adjacent_k` nodes adjacent to each initial root will be
+                included in the set of initial candidates. To fetch only in the
+                neighborhood of these nodes, set `fetch_k = 0`.
             k: Number of Documents to return. Defaults to 4.
             fetch_k: Number of Documents to fetch via similarity.
                 Defaults to 100.
@@ -509,18 +541,22 @@ class GraphVectorStore(VectorStore):
                 diversity and 1 to minimum diversity. Defaults to 0.5.
             score_threshold: Only documents with a score greater than or equal
                 this threshold will be chosen. Defaults to negative infinity.
+            filter: Optional metadata to filter the results.
+            **kwargs: Additional keyword arguments.
         """
         iterator = iter(
             await run_in_executor(
                 None,
                 self.mmr_traversal_search,
                 query,
                 initial_roots=initial_roots,
                 k=k,
                 fetch_k=fetch_k,
                 adjacent_k=adjacent_k,
                 depth=depth,
                 lambda_mult=lambda_mult,
                 score_threshold=score_threshold,
+                filter=filter,
                 **kwargs,
             )
         )
@@ -544,6 +580,11 @@ class GraphVectorStore(VectorStore):
         lambda_mult: float = 0.5,
         **kwargs: Any,
     ) -> list[Document]:
+        if kwargs.get("depth", 0) > 0:
+            logger.warning(
+                "'mmr' search started with depth > 0. "
+                "Maybe you meant to do a 'mmr_traversal' search?"
+            )
         return list(
             self.mmr_traversal_search(
                 query, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, depth=0
@@ -573,7 +614,7 @@ class GraphVectorStore(VectorStore):
             raise ValueError(
                 f"search_type of {search_type} not allowed. Expected "
                 "search_type to be 'similarity', 'similarity_score_threshold', "
-                "'mmr' or 'traversal'."
+                "'mmr', 'traversal', or 'mmr_traversal'."
             )

     async def asearch(
@@ -590,11 +631,13 @@ class GraphVectorStore(VectorStore):
             return await self.amax_marginal_relevance_search(query, **kwargs)
         elif search_type == "traversal":
             return [doc async for doc in self.atraversal_search(query, **kwargs)]
+        elif search_type == "mmr_traversal":
+            return [doc async for doc in self.ammr_traversal_search(query, **kwargs)]
         else:
             raise ValueError(
                 f"search_type of {search_type} not allowed. Expected "
                 "search_type to be 'similarity', 'similarity_score_threshold', "
-                "'mmr' or 'traversal'."
+                "'mmr', 'traversal', or 'mmr_traversal'."
             )

     def as_retriever(self, **kwargs: Any) -> GraphVectorStoreRetriever:
@@ -606,13 +649,14 @@ class GraphVectorStore(VectorStore):

         - search_type (Optional[str]): Defines the type of search that
             the Retriever should perform.
-            Can be ``traversal`` (default), ``similarity``, ``mmr``, or
-            ``similarity_score_threshold``.
+            Can be ``traversal`` (default), ``similarity``, ``mmr``,
+            ``mmr_traversal``, or ``similarity_score_threshold``.
         - search_kwargs (Optional[Dict]): Keyword arguments to pass to the
             search function. Can include things like:

             - k(int): Amount of documents to return (Default: 4).
             - depth(int): The maximum depth of edges to traverse (Default: 1).
+                Only applies to search_type: ``traversal`` and ``mmr_traversal``.
             - score_threshold(float): Minimum relevance threshold
                 for similarity_score_threshold.
             - fetch_k(int): Amount of documents to pass to MMR algorithm
@@ -629,21 +673,21 @@ class GraphVectorStore(VectorStore):
             # Retrieve documents traversing edges
             docsearch.as_retriever(
                 search_type="traversal",
-                search_kwargs={'k': 6, 'depth': 3}
+                search_kwargs={'k': 6, 'depth': 2}
             )

-            # Retrieve more documents with higher diversity
+            # Retrieve documents with higher diversity
             # Useful if your dataset has many similar documents
             docsearch.as_retriever(
-                search_type="mmr",
-                search_kwargs={'k': 6, 'lambda_mult': 0.25}
+                search_type="mmr_traversal",
+                search_kwargs={'k': 6, 'lambda_mult': 0.25, 'depth': 2}
             )

             # Fetch more documents for the MMR algorithm to consider
             # But only return the top 5
             docsearch.as_retriever(
-                search_type="mmr",
-                search_kwargs={'k': 5, 'fetch_k': 50}
+                search_type="mmr_traversal",
+                search_kwargs={'k': 5, 'fetch_k': 50, 'depth': 2}
             )

             # Only retrieve documents that have a relevance score
@@ -657,7 +701,7 @@ class GraphVectorStore(VectorStore):
             docsearch.as_retriever(search_kwargs={'k': 1})

         """
-        return GraphVectorStoreRetriever(vectorstore=self, **kwargs)
+        return GraphVectorStoreRetriever(vector_store=self, **kwargs)


 @beta(message="Added in version 0.3.1 of langchain_community. API subject to change.")
@@ -744,7 +788,7 @@ class GraphVectorStoreRetriever(VectorStoreRetriever):
     Passing search parameters
     -------------------------

-    We can pass parameters to the underlying graph vectorstore's search methods using
+    We can pass parameters to the underlying graph vector store's search methods using
     ``search_kwargs``.

     Specifying graph traversal depth
@@ -793,7 +837,7 @@ class GraphVectorStoreRetriever(VectorStoreRetriever):
     retriever = graph_vectorstore.as_retriever(search_kwargs={"score_threshold": 0.5})
     """  # noqa: E501

-    vectorstore: GraphVectorStore
+    vector_store: GraphVectorStore
     """GraphVectorStore to use for retrieval."""
     search_type: str = "traversal"
     """Type of search to perform. Defaults to "traversal"."""
@@ -809,10 +853,10 @@ class GraphVectorStoreRetriever(VectorStoreRetriever):
         self, query: str, *, run_manager: CallbackManagerForRetrieverRun
     ) -> list[Document]:
         if self.search_type == "traversal":
-            return list(self.vectorstore.traversal_search(query, **self.search_kwargs))
+            return list(self.vector_store.traversal_search(query, **self.search_kwargs))
         elif self.search_type == "mmr_traversal":
             return list(
-                self.vectorstore.mmr_traversal_search(query, **self.search_kwargs)
+                self.vector_store.mmr_traversal_search(query, **self.search_kwargs)
             )
         else:
             return super()._get_relevant_documents(query, run_manager=run_manager)
@@ -823,14 +867,14 @@ class GraphVectorStoreRetriever(VectorStoreRetriever):
         if self.search_type == "traversal":
             return [
                 doc
-                async for doc in self.vectorstore.atraversal_search(
+                async for doc in self.vector_store.atraversal_search(
                     query, **self.search_kwargs
                 )
             ]
         elif self.search_type == "mmr_traversal":
             return [
                 doc
-                async for doc in self.vectorstore.ammr_traversal_search(
+                async for doc in self.vector_store.ammr_traversal_search(
                     query, **self.search_kwargs
                 )
             ]
```
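Taken together, the `base.py` changes mean retriever callers can now request the MMR-traversal behavior directly, and code that touched the old `vectorstore` attribute needs a one-word update. A hedged sketch; the store construction reuses names from this commit's tests (`graph_test_keyspace`, `graph_test_table`) and `DeterministicFakeEmbedding` is used only to keep the example self-contained:

```python
from cassandra.cluster import Cluster
from langchain_core.embeddings import DeterministicFakeEmbedding
from langchain_community.graph_vectorstores import CassandraGraphVectorStore

graph_store = CassandraGraphVectorStore(
    embedding=DeterministicFakeEmbedding(size=8),
    session=Cluster().connect(),
    keyspace="graph_test_keyspace",
    table_name="graph_test_table",
)

# "mmr_traversal" is now accepted alongside "traversal", "similarity",
# "similarity_score_threshold", and "mmr".
retriever = graph_store.as_retriever(
    search_type="mmr_traversal",
    search_kwargs={"k": 5, "fetch_k": 50, "depth": 2, "lambda_mult": 0.25},
)
docs = retriever.invoke("What is the meaning of life?")

# Field rename: GraphVectorStoreRetriever now exposes `vector_store`, so any
# code reading `retriever.vectorstore` becomes:
underlying = retriever.vector_store
```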
File diff suppressed because it is too large
New `mmr_helper` module, added in full:

```diff
@@ -0,0 +1,272 @@
+"""Tools for the Graph Traversal Maximal Marginal Relevance (MMR) reranking."""
+
+from __future__ import annotations
+
+import dataclasses
+from typing import TYPE_CHECKING, Iterable
+
+import numpy as np
+
+from langchain_community.utils.math import cosine_similarity
+
+if TYPE_CHECKING:
+    from numpy.typing import NDArray
+
+
+def _emb_to_ndarray(embedding: list[float]) -> NDArray[np.float32]:
+    emb_array = np.array(embedding, dtype=np.float32)
+    if emb_array.ndim == 1:
+        emb_array = np.expand_dims(emb_array, axis=0)
+    return emb_array
+
+
+NEG_INF = float("-inf")
+
+
+@dataclasses.dataclass
+class _Candidate:
+    id: str
+    similarity: float
+    weighted_similarity: float
+    weighted_redundancy: float
+    score: float = dataclasses.field(init=False)
+
+    def __post_init__(self) -> None:
+        self.score = self.weighted_similarity - self.weighted_redundancy
+
+    def update_redundancy(self, new_weighted_redundancy: float) -> None:
+        if new_weighted_redundancy > self.weighted_redundancy:
+            self.weighted_redundancy = new_weighted_redundancy
+            self.score = self.weighted_similarity - self.weighted_redundancy
+
+
+class MmrHelper:
+    """Helper for executing an MMR traversal query.
+
+    Args:
+        query_embedding: The embedding of the query to use for scoring.
+        lambda_mult: Number between 0 and 1 that determines the degree
+            of diversity among the results with 0 corresponding to maximum
+            diversity and 1 to minimum diversity. Defaults to 0.5.
+        score_threshold: Only documents with a score greater than or equal
+            this threshold will be chosen. Defaults to -infinity.
+    """
+
+    dimensions: int
+    """Dimensions of the embedding."""
+
+    query_embedding: NDArray[np.float32]
+    """Embedding of the query as a (1,dim) ndarray."""
+
+    lambda_mult: float
+    """Number between 0 and 1.
+
+    Determines the degree of diversity among the results with 0 corresponding to
+    maximum diversity and 1 to minimum diversity."""
+
+    lambda_mult_complement: float
+    """1 - lambda_mult."""
+
+    score_threshold: float
+    """Only documents with a score greater than or equal to this will be chosen."""
+
+    selected_ids: list[str]
+    """List of selected IDs (in selection order)."""
+
+    selected_mmr_scores: list[float]
+    """List of MMR score at the time each document is selected."""
+
+    selected_similarity_scores: list[float]
+    """List of similarity score for each selected document."""
+
+    selected_embeddings: NDArray[np.float32]
+    """(N, dim) ndarray with a row for each selected node."""
+
+    candidate_id_to_index: dict[str, int]
+    """Dictionary of candidate IDs to indices in candidates and candidate_embeddings."""
+    candidates: list[_Candidate]
+    """List containing information about candidates.
+
+    Same order as rows in `candidate_embeddings`.
+    """
+    candidate_embeddings: NDArray[np.float32]
+    """(N, dim) ndarray with a row for each candidate."""
+
+    best_score: float
+    best_id: str | None
+
+    def __init__(
+        self,
+        k: int,
+        query_embedding: list[float],
+        lambda_mult: float = 0.5,
+        score_threshold: float = NEG_INF,
+    ) -> None:
+        """Create a new Traversal MMR helper."""
+        self.query_embedding = _emb_to_ndarray(query_embedding)
+        self.dimensions = self.query_embedding.shape[1]
+
+        self.lambda_mult = lambda_mult
+        self.lambda_mult_complement = 1 - lambda_mult
+        self.score_threshold = score_threshold
+
+        self.selected_ids = []
+        self.selected_similarity_scores = []
+        self.selected_mmr_scores = []
+
+        # List of selected embeddings (in selection order).
+        self.selected_embeddings = np.ndarray((k, self.dimensions), dtype=np.float32)
+
+        self.candidate_id_to_index = {}
+
+        # List of the candidates.
+        self.candidates = []
+        # numpy n-dimensional array of the candidate embeddings.
+        self.candidate_embeddings = np.ndarray((0, self.dimensions), dtype=np.float32)
+
+        self.best_score = NEG_INF
+        self.best_id = None
+
+    def candidate_ids(self) -> Iterable[str]:
+        """Return the IDs of the candidates."""
+        return self.candidate_id_to_index.keys()
+
+    def _already_selected_embeddings(self) -> NDArray[np.float32]:
+        """Return the selected embeddings sliced to the already assigned values."""
+        selected = len(self.selected_ids)
+        return np.vsplit(self.selected_embeddings, [selected])[0]
+
+    def _pop_candidate(self, candidate_id: str) -> tuple[float, NDArray[np.float32]]:
+        """Pop the candidate with the given ID.
+
+        Returns:
+            The similarity score and embedding of the candidate.
+        """
+        # Get the embedding for the id.
+        index = self.candidate_id_to_index.pop(candidate_id)
+        if self.candidates[index].id != candidate_id:
+            msg = (
+                "ID in self.candidate_id_to_index doesn't match the ID of the "
+                "corresponding index in self.candidates"
+            )
+            raise ValueError(msg)
+        embedding: NDArray[np.float32] = self.candidate_embeddings[index].copy()
+
+        # Swap that index with the last index in the candidates and
+        # candidate_embeddings.
+        last_index = self.candidate_embeddings.shape[0] - 1
+
+        similarity = 0.0
+        if index == last_index:
+            # Already the last item. We don't need to swap.
+            similarity = self.candidates.pop().similarity
+        else:
+            self.candidate_embeddings[index] = self.candidate_embeddings[last_index]
+
+            similarity = self.candidates[index].similarity
+
+            old_last = self.candidates.pop()
+            self.candidates[index] = old_last
+            self.candidate_id_to_index[old_last.id] = index
+
+        self.candidate_embeddings = np.vsplit(self.candidate_embeddings, [last_index])[
+            0
+        ]
+
+        return similarity, embedding
+
+    def pop_best(self) -> str | None:
+        """Select and pop the best item being considered.
+
+        Updates the consideration set based on it.
+
+        Returns:
+            A tuple containing the ID of the best item.
+        """
+        if self.best_id is None or self.best_score < self.score_threshold:
+            return None
+
+        # Get the selection and remove from candidates.
+        selected_id = self.best_id
+        selected_similarity, selected_embedding = self._pop_candidate(selected_id)
+
+        # Add the ID and embedding to the selected information.
+        selection_index = len(self.selected_ids)
+        self.selected_ids.append(selected_id)
+        self.selected_mmr_scores.append(self.best_score)
+        self.selected_similarity_scores.append(selected_similarity)
+        self.selected_embeddings[selection_index] = selected_embedding
+
+        # Reset the best score / best ID.
+        self.best_score = NEG_INF
+        self.best_id = None
+
+        # Update the candidates redundancy, tracking the best node.
+        if self.candidate_embeddings.shape[0] > 0:
+            similarity = cosine_similarity(
+                self.candidate_embeddings, np.expand_dims(selected_embedding, axis=0)
+            )
+            for index, candidate in enumerate(self.candidates):
+                candidate.update_redundancy(similarity[index][0])
+                if candidate.score > self.best_score:
+                    self.best_score = candidate.score
+                    self.best_id = candidate.id
+
+        return selected_id
+
+    def add_candidates(self, candidates: dict[str, list[float]]) -> None:
+        """Add candidates to the consideration set."""
+        # Determine the keys to actually include.
+        # These are the candidates that aren't already selected
+        # or under consideration.
+        include_ids_set = set(candidates.keys())
+        include_ids_set.difference_update(self.selected_ids)
+        include_ids_set.difference_update(self.candidate_id_to_index.keys())
+        include_ids = list(include_ids_set)
+
+        # Now, build up a matrix of the remaining candidate embeddings.
+        # And add them to the
+        new_embeddings: NDArray[np.float32] = np.ndarray(
+            (
+                len(include_ids),
+                self.dimensions,
+            )
+        )
+        offset = self.candidate_embeddings.shape[0]
+        for index, candidate_id in enumerate(include_ids):
+            if candidate_id in include_ids:
+                self.candidate_id_to_index[candidate_id] = offset + index
+                embedding = candidates[candidate_id]
+                new_embeddings[index] = embedding
+
+        # Compute the similarity to the query.
+        similarity = cosine_similarity(new_embeddings, self.query_embedding)
+
+        # Compute the distance metrics of all of pairs in the selected set with
+        # the new candidates.
+        redundancy = cosine_similarity(
+            new_embeddings, self._already_selected_embeddings()
+        )
+        for index, candidate_id in enumerate(include_ids):
+            max_redundancy = 0.0
+            if redundancy.shape[0] > 0:
+                max_redundancy = redundancy[index].max()
+            candidate = _Candidate(
+                id=candidate_id,
+                similarity=similarity[index][0],
+                weighted_similarity=self.lambda_mult * similarity[index][0],
+                weighted_redundancy=self.lambda_mult_complement * max_redundancy,
+            )
+            self.candidates.append(candidate)
+
+            if candidate.score >= self.best_score:
+                self.best_score = candidate.score
+                self.best_id = candidate.id
+
+        # Add the new embeddings to the candidate set.
+        self.candidate_embeddings = np.vstack(
+            (
+                self.candidate_embeddings,
+                new_embeddings,
+            )
+        )
```
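To make the control flow of the new helper concrete, here is a hedged sketch of the select-then-expand loop that an MMR traversal built on `MmrHelper` would run. `fetch_initial_candidates` and `fetch_neighborhood` are hypothetical callbacks standing in for the store's vector queries; the dicts at the bottom are toy data, not the real store API:

```python
from langchain_community.graph_vectorstores import MmrHelper


def mmr_traversal_sketch(
    query_embedding: list[float],
    k: int,
    fetch_initial_candidates,  # hypothetical: () -> dict[node_id, embedding]
    fetch_neighborhood,        # hypothetical: (node_id) -> dict[node_id, embedding]
) -> list[str]:
    helper = MmrHelper(k=k, query_embedding=query_embedding, lambda_mult=0.5)
    helper.add_candidates(fetch_initial_candidates())

    while len(helper.selected_ids) < k:
        # pop_best applies the MMR score (weighted similarity minus weighted
        # redundancy) and returns None once the threshold cuts everything off.
        selected_id = helper.pop_best()
        if selected_id is None:
            break
        # Expand the frontier: adjacent nodes become new candidates. IDs already
        # selected or under consideration are filtered out by add_candidates.
        helper.add_candidates(fetch_neighborhood(selected_id))

    return helper.selected_ids


# Toy usage with in-memory data.
embeddings = {"a": [1.0, 0.0], "b": [0.9, 0.1], "c": [0.0, 1.0]}
neighbors = {"a": {"c": [0.0, 1.0]}, "b": {}, "c": {}}
ids = mmr_traversal_sketch(
    query_embedding=[1.0, 0.0],
    k=2,
    fetch_initial_candidates=lambda: embeddings,
    fetch_neighborhood=lambda node_id: neighbors[node_id],
)
print(ids)  # "a" first, then a diverse second pick
```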
`Cassandra` vector store (`langchain_community.vectorstores.cassandra`):

```diff
@@ -4,6 +4,7 @@ import asyncio
 import importlib.metadata
 import typing
 import uuid
+import warnings
 from typing import (
     Any,
     Awaitable,
@@ -501,10 +502,13 @@ class Cassandra(VectorStore):
         )

     def get_by_document_id(self, document_id: str) -> Document | None:
-        """Get by document ID.
+        """Retrieve a single document from the store, given its document ID.

         Args:
-            document_id: the document ID to get.
+            document_id: The document ID
+
+        Returns:
+            The the document if it exists. Otherwise None.
         """
         row = self.table.get(row_id=document_id)
         if row is None:
@@ -512,10 +516,13 @@ class Cassandra(VectorStore):
         return self._row_to_document(row=row)

     async def aget_by_document_id(self, document_id: str) -> Document | None:
-        """Get by document ID.
+        """Retrieve a single document from the store, given its document ID.

         Args:
-            document_id: the document ID to get.
+            document_id: The document ID
+
+        Returns:
+            The the document if it exists. Otherwise None.
         """
         row = await self.table.aget(row_id=document_id)
         if row is None:
@@ -524,28 +531,30 @@ class Cassandra(VectorStore):

     def metadata_search(
         self,
-        metadata: dict[str, Any] = {},  # noqa: B006
+        filter: dict[str, Any] = {},  # noqa: B006
         n: int = 5,
     ) -> Iterable[Document]:
         """Get documents via a metadata search.

         Args:
-            metadata: the metadata to query for.
+            filter: the metadata to query for.
             n: the maximum number of documents to return.
         """
-        rows = self.table.find_entries(metadata=metadata, n=n)
+        rows = self.table.find_entries(metadata=filter, n=n)
         return [self._row_to_document(row=row) for row in rows if row]

     async def ametadata_search(
         self,
-        metadata: dict[str, Any] = {},  # noqa: B006
+        filter: dict[str, Any] = {},  # noqa: B006
         n: int = 5,
     ) -> Iterable[Document]:
         """Get documents via a metadata search.

         Args:
-            metadata: the metadata to query for.
+            filter: the metadata to query for.
             n: the maximum number of documents to return.
         """
-        rows = await self.table.afind_entries(metadata=metadata, n=n)
+        rows = await self.table.afind_entries(metadata=filter, n=n)
         return [self._row_to_document(row=row) for row in rows]

     async def asimilarity_search_with_embedding_id_by_vector(
@@ -1126,6 +1135,24 @@ class Cassandra(VectorStore):
             body_search=body_search,
         )

+    @staticmethod
+    def _build_docs_from_texts(
+        texts: List[str],
+        metadatas: Optional[List[dict]] = None,
+        ids: Optional[List[str]] = None,
+    ) -> List[Document]:
+        docs: List[Document] = []
+        for i, text in enumerate(texts):
+            doc = Document(
+                page_content=text,
+            )
+            if metadatas is not None:
+                doc.metadata = metadatas[i]
+            if ids is not None:
+                doc.id = ids[i]
+            docs.append(doc)
+        return docs
+
     @classmethod
     def from_texts(
         cls: Type[CVST],
@@ -1137,13 +1164,12 @@ class Cassandra(VectorStore):
         keyspace: Optional[str] = None,
         table_name: str = "",
         ids: Optional[List[str]] = None,
-        batch_size: int = 16,
         ttl_seconds: Optional[int] = None,
         body_index_options: Optional[List[Tuple[str, Any]]] = None,
         metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
         **kwargs: Any,
     ) -> CVST:
-        """Create a Cassandra vectorstore from raw texts.
+        """Create a Cassandra vector store from raw texts.

         Args:
             texts: Texts to add to the vectorstore.
@@ -1155,16 +1181,32 @@ class Cassandra(VectorStore):
             If not provided, it is resolved from cassio.
             table_name: Cassandra table (required).
             ids: Optional list of IDs associated with the texts.
-            batch_size: Number of concurrent requests to send to the server.
-                Defaults to 16.
             ttl_seconds: Optional time-to-live for the added texts.
             body_index_options: Optional options used to create the body index.
                 Eg. body_index_options = [cassio.table.cql.STANDARD_ANALYZER]
+            metadata_indexing: Optional specification of a metadata indexing policy,
+                i.e. to fine-tune which of the metadata fields are indexed.
+                It can be a string ("all" or "none"), or a 2-tuple. The following
+                means that all fields except 'f1', 'f2' ... are NOT indexed:
+                metadata_indexing=("allowlist", ["f1", "f2", ...])
+                The following means all fields EXCEPT 'g1', 'g2', ... are indexed:
+                metadata_indexing("denylist", ["g1", "g2", ...])
+                The default is to index every metadata field.
+                Note: if you plan to have massive unique text metadata entries,
+                consider not indexing them for performance
+                (and to overcome max-length limitations).

         Returns:
-            a Cassandra vectorstore.
+            a Cassandra vector store.
         """
-        store = cls(
+        docs = cls._build_docs_from_texts(
+            texts=texts,
+            metadatas=metadatas,
+            ids=ids,
+        )
+
+        return cls.from_documents(
+            documents=docs,
             embedding=embedding,
             session=session,
             keyspace=keyspace,
@@ -1172,11 +1214,8 @@ class Cassandra(VectorStore):
             ttl_seconds=ttl_seconds,
             body_index_options=body_index_options,
             metadata_indexing=metadata_indexing,
+            **kwargs,
         )
-        store.add_texts(
-            texts=texts, metadatas=metadatas, ids=ids, batch_size=batch_size
-        )
-        return store

     @classmethod
     async def afrom_texts(
@@ -1189,13 +1228,12 @@ class Cassandra(VectorStore):
         keyspace: Optional[str] = None,
         table_name: str = "",
         ids: Optional[List[str]] = None,
-        concurrency: int = 16,
         ttl_seconds: Optional[int] = None,
         body_index_options: Optional[List[Tuple[str, Any]]] = None,
         metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
         **kwargs: Any,
     ) -> CVST:
-        """Create a Cassandra vectorstore from raw texts.
+        """Create a Cassandra vector store from raw texts.

         Args:
             texts: Texts to add to the vectorstore.
@@ -1207,29 +1245,51 @@ class Cassandra(VectorStore):
             If not provided, it is resolved from cassio.
             table_name: Cassandra table (required).
             ids: Optional list of IDs associated with the texts.
-            concurrency: Number of concurrent queries to send to the database.
-                Defaults to 16.
             ttl_seconds: Optional time-to-live for the added texts.
             body_index_options: Optional options used to create the body index.
                 Eg. body_index_options = [cassio.table.cql.STANDARD_ANALYZER]
+            metadata_indexing: Optional specification of a metadata indexing policy,
+                i.e. to fine-tune which of the metadata fields are indexed.
+                It can be a string ("all" or "none"), or a 2-tuple. The following
+                means that all fields except 'f1', 'f2' ... are NOT indexed:
+                metadata_indexing=("allowlist", ["f1", "f2", ...])
+                The following means all fields EXCEPT 'g1', 'g2', ... are indexed:
+                metadata_indexing("denylist", ["g1", "g2", ...])
+                The default is to index every metadata field.
+                Note: if you plan to have massive unique text metadata entries,
+                consider not indexing them for performance
+                (and to overcome max-length limitations).

         Returns:
-            a Cassandra vectorstore.
+            a Cassandra vector store.
         """
-        store = cls(
+        docs = cls._build_docs_from_texts(
+            texts=texts,
+            metadatas=metadatas,
+            ids=ids,
+        )
+
+        return await cls.afrom_documents(
+            documents=docs,
             embedding=embedding,
             session=session,
             keyspace=keyspace,
             table_name=table_name,
             ttl_seconds=ttl_seconds,
             setup_mode=SetupMode.ASYNC,
             body_index_options=body_index_options,
             metadata_indexing=metadata_indexing,
+            **kwargs,
         )
-        await store.aadd_texts(
-            texts=texts, metadatas=metadatas, ids=ids, concurrency=concurrency
-        )
-        return store

+    @staticmethod
+    def _add_ids_to_docs(
+        docs: List[Document],
+        ids: Optional[List[str]] = None,
+    ) -> List[Document]:
+        if ids is not None:
+            for doc, doc_id in zip(docs, ids):
+                doc.id = doc_id
+        return docs

     @classmethod
     def from_documents(
@@ -1241,13 +1301,12 @@ class Cassandra(VectorStore):
         keyspace: Optional[str] = None,
         table_name: str = "",
         ids: Optional[List[str]] = None,
-        batch_size: int = 16,
         ttl_seconds: Optional[int] = None,
         body_index_options: Optional[List[Tuple[str, Any]]] = None,
         metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
         **kwargs: Any,
     ) -> CVST:
-        """Create a Cassandra vectorstore from a document list.
+        """Create a Cassandra vector store from a document list.

         Args:
             documents: Documents to add to the vectorstore.
@@ -1258,31 +1317,48 @@ class Cassandra(VectorStore):
             If not provided, it is resolved from cassio.
             table_name: Cassandra table (required).
             ids: Optional list of IDs associated with the documents.
-            batch_size: Number of concurrent requests to send to the server.
-                Defaults to 16.
             ttl_seconds: Optional time-to-live for the added documents.
             body_index_options: Optional options used to create the body index.
                 Eg. body_index_options = [cassio.table.cql.STANDARD_ANALYZER]
+            metadata_indexing: Optional specification of a metadata indexing policy,
+                i.e. to fine-tune which of the metadata fields are indexed.
+                It can be a string ("all" or "none"), or a 2-tuple. The following
+                means that all fields except 'f1', 'f2' ... are NOT indexed:
+                metadata_indexing=("allowlist", ["f1", "f2", ...])
+                The following means all fields EXCEPT 'g1', 'g2', ... are indexed:
+                metadata_indexing("denylist", ["g1", "g2", ...])
+                The default is to index every metadata field.
+                Note: if you plan to have massive unique text metadata entries,
+                consider not indexing them for performance
+                (and to overcome max-length limitations).

         Returns:
-            a Cassandra vectorstore.
+            a Cassandra vector store.
         """
-        texts = [doc.page_content for doc in documents]
-        metadatas = [doc.metadata for doc in documents]
-        return cls.from_texts(
-            texts=texts,
+        if ids is not None:
+            warnings.warn(
+                (
+                    "Parameter `ids` to Cassandra's `from_documents` "
+                    "method is deprecated. Please set the supplied documents' "
+                    "`.id` attribute instead. The id attribute of Document "
+                    "is ignored as long as the `ids` parameter is passed."
+                ),
+                DeprecationWarning,
+                stacklevel=2,
+            )
+
+        store = cls(
             embedding=embedding,
-            metadatas=metadatas,
             session=session,
             keyspace=keyspace,
             table_name=table_name,
-            ids=ids,
-            batch_size=batch_size,
             ttl_seconds=ttl_seconds,
             body_index_options=body_index_options,
             metadata_indexing=metadata_indexing,
             **kwargs,
         )
+        store.add_documents(documents=cls._add_ids_to_docs(docs=documents, ids=ids))
+        return store

     @classmethod
     async def afrom_documents(
@@ -1294,13 +1370,12 @@ class Cassandra(VectorStore):
         keyspace: Optional[str] = None,
         table_name: str = "",
         ids: Optional[List[str]] = None,
-        concurrency: int = 16,
         ttl_seconds: Optional[int] = None,
         body_index_options: Optional[List[Tuple[str, Any]]] = None,
         metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
         **kwargs: Any,
     ) -> CVST:
-        """Create a Cassandra vectorstore from a document list.
+        """Create a Cassandra vector store from a document list.

         Args:
             documents: Documents to add to the vectorstore.
@@ -1311,31 +1386,51 @@ class Cassandra(VectorStore):
             If not provided, it is resolved from cassio.
             table_name: Cassandra table (required).
             ids: Optional list of IDs associated with the documents.
-            concurrency: Number of concurrent queries to send to the database.
-                Defaults to 16.
             ttl_seconds: Optional time-to-live for the added documents.
             body_index_options: Optional options used to create the body index.
                 Eg. body_index_options = [cassio.table.cql.STANDARD_ANALYZER]
+            metadata_indexing: Optional specification of a metadata indexing policy,
+                i.e. to fine-tune which of the metadata fields are indexed.
+                It can be a string ("all" or "none"), or a 2-tuple. The following
+                means that all fields except 'f1', 'f2' ... are NOT indexed:
+                metadata_indexing=("allowlist", ["f1", "f2", ...])
+                The following means all fields EXCEPT 'g1', 'g2', ... are indexed:
+                metadata_indexing("denylist", ["g1", "g2", ...])
+                The default is to index every metadata field.
+                Note: if you plan to have massive unique text metadata entries,
+                consider not indexing them for performance
+                (and to overcome max-length limitations).

         Returns:
-            a Cassandra vectorstore.
+            a Cassandra vector store.
         """
-        texts = [doc.page_content for doc in documents]
-        metadatas = [doc.metadata for doc in documents]
-        return await cls.afrom_texts(
-            texts=texts,
+        if ids is not None:
+            warnings.warn(
+                (
+                    "Parameter `ids` to Cassandra's `afrom_documents` "
+                    "method is deprecated. Please set the supplied documents' "
+                    "`.id` attribute instead. The id attribute of Document "
+                    "is ignored as long as the `ids` parameter is passed."
+                ),
+                DeprecationWarning,
+                stacklevel=2,
+            )
+
+        store = cls(
             embedding=embedding,
-            metadatas=metadatas,
             session=session,
             keyspace=keyspace,
             table_name=table_name,
-            ids=ids,
-            concurrency=concurrency,
             ttl_seconds=ttl_seconds,
             setup_mode=SetupMode.ASYNC,
             body_index_options=body_index_options,
             metadata_indexing=metadata_indexing,
             **kwargs,
         )
+        await store.aadd_documents(
+            documents=cls._add_ids_to_docs(docs=documents, ids=ids)
+        )
+        return store

     def as_retriever(
         self,
```
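Two caller-visible consequences of the `cassandra.py` hunks above, in one hedged sketch. `DeterministicFakeEmbedding`, `my_keyspace`, and `my_table` are stand-ins used only to keep the example self-contained; they are not part of this commit:

```python
from cassandra.cluster import Cluster
from langchain_core.documents import Document
from langchain_core.embeddings import DeterministicFakeEmbedding
from langchain_community.vectorstores import Cassandra

embedding = DeterministicFakeEmbedding(size=8)  # placeholder Embeddings
session = Cluster().connect()                   # placeholder session

# 2. Passing `ids=` to from_documents/afrom_documents is deprecated; set the
#    Document.id attribute instead:
documents = [
    Document(id="doc-1", page_content="Hello World", metadata={"source": "tweets"}),
]
store = Cassandra.from_documents(
    documents=documents,
    embedding=embedding,
    session=session,
    keyspace="my_keyspace",
    table_name="my_table",
)

# 1. metadata_search / ametadata_search now name their argument `filter`
#    rather than `metadata`:
docs = store.metadata_search(filter={"source": "tweets"}, n=5)
```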
@ -1,116 +1,255 @@
|
||||
import math
|
||||
import os
|
||||
from typing import Iterable, List, Optional, Type
|
||||
"""Test of Apache Cassandra graph vector g_store class `CassandraGraphVectorStore`"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Generator, Iterable, List, Optional
|
||||
|
||||
import pytest
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
from langchain_community.graph_vectorstores import CassandraGraphVectorStore
|
||||
from langchain_community.graph_vectorstores.links import METADATA_LINKS_KEY, Link
|
||||
from langchain_community.graph_vectorstores.base import Node
|
||||
from langchain_community.graph_vectorstores.links import (
|
||||
METADATA_LINKS_KEY,
|
||||
Link,
|
||||
add_links,
|
||||
)
|
||||
from tests.integration_tests.cache.fake_embeddings import (
|
||||
AngularTwoDimensionalEmbeddings,
|
||||
FakeEmbeddings,
|
||||
)
|
||||
|
||||
CASSANDRA_DEFAULT_KEYSPACE = "graph_test_keyspace"
|
||||
TEST_KEYSPACE = "graph_test_keyspace"
|
||||
|
||||
|
||||
def _get_graph_store(
|
||||
embedding_class: Type[Embeddings], documents: Iterable[Document] = ()
|
||||
) -> CassandraGraphVectorStore:
|
||||
import cassio
|
||||
from cassandra.cluster import Cluster
|
||||
from cassio.config import check_resolve_session, resolve_keyspace
|
||||
|
||||
node_table = "graph_test_node_table"
|
||||
edge_table = "graph_test_edge_table"
|
||||
|
||||
if any(
|
||||
env_var in os.environ
|
||||
for env_var in [
|
||||
"CASSANDRA_CONTACT_POINTS",
|
||||
"ASTRA_DB_APPLICATION_TOKEN",
|
||||
"ASTRA_DB_INIT_STRING",
|
||||
]
|
||||
):
|
||||
cassio.init(auto=True)
|
||||
session = check_resolve_session()
|
||||
else:
|
||||
cluster = Cluster()
|
||||
session = cluster.connect()
|
||||
keyspace = resolve_keyspace() or CASSANDRA_DEFAULT_KEYSPACE
|
||||
cassio.init(session=session, keyspace=keyspace)
|
||||
# ensure keyspace exists
|
||||
session.execute(
|
||||
(
|
||||
f"CREATE KEYSPACE IF NOT EXISTS {keyspace} "
|
||||
f"WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}"
|
||||
)
|
||||
)
|
||||
session.execute(f"DROP TABLE IF EXISTS {keyspace}.{node_table}")
|
||||
session.execute(f"DROP TABLE IF EXISTS {keyspace}.{edge_table}")
|
||||
store = CassandraGraphVectorStore.from_documents(
|
||||
documents,
|
||||
embedding=embedding_class(),
|
||||
session=session,
|
||||
keyspace=keyspace,
|
||||
node_table=node_table,
|
||||
targets_table=edge_table,
|
||||
)
|
||||
return store
|
||||
|
||||
|
||||
class FakeEmbeddings(Embeddings):
|
||||
"""Fake embeddings functionality for testing."""
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Return simple embeddings.
|
||||
Embeddings encode each text as its index."""
|
||||
return [[float(1.0)] * 9 + [float(i)] for i in range(len(texts))]
|
||||
|
||||
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
return self.embed_documents(texts)
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Return constant query embeddings.
|
||||
Embeddings are identical to embed_documents(texts)[0].
|
||||
Distance to each text will be that text's index,
|
||||
as it was passed to embed_documents."""
|
||||
return [float(1.0)] * 9 + [float(0.0)]
|
||||
|
||||
async def aembed_query(self, text: str) -> List[float]:
|
||||
return self.embed_query(text)
|
||||
|
||||
|
||||
class AngularTwoDimensionalEmbeddings(Embeddings):
|
||||
"""
|
||||
From angles (as strings in units of pi) to unit embedding vectors on a circle.
|
||||
class ParserEmbeddings(Embeddings):
|
||||
"""Parse input texts: if they are json for a List[float], fine.
|
||||
Otherwise, return all zeros and call it a day.
|
||||
"""
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""
|
||||
Make a list of texts into a list of embedding vectors.
|
||||
"""
|
||||
return [self.embed_query(text) for text in texts]
|
||||
def __init__(self, dimension: int) -> None:
|
||||
self.dimension = dimension
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""
|
||||
Convert input text to a 'vector' (list of floats).
|
||||
If the text is a number, use it as the angle for the
|
||||
unit vector in units of pi.
|
||||
Any other input text becomes the singular result [0, 0] !
|
||||
"""
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
return [self.embed_query(txt) for txt in texts]
|
||||
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
try:
|
||||
angle = float(text)
|
||||
return [math.cos(angle * math.pi), math.sin(angle * math.pi)]
|
||||
except ValueError:
|
||||
# Assume: just test string, no attention is paid to values.
|
||||
return [0.0, 0.0]
|
||||
vals = json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
return [0.0] * self.dimension
|
||||
else:
|
||||
assert len(vals) == self.dimension
|
||||
return vals
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def embedding_d2() -> Embeddings:
|
||||
return ParserEmbeddings(dimension=2)
|
||||
|
||||
|
||||
class EarthEmbeddings(Embeddings):
|
||||
def get_vector_near(self, value: float) -> List[float]:
|
||||
base_point = [value, (1 - value**2) ** 0.5]
|
||||
fluctuation = random.random() / 100.0
|
||||
return [base_point[0] + fluctuation, base_point[1] - fluctuation]
|
||||
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
return [self.embed_query(txt) for txt in texts]
|
||||
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
words = set(text.lower().split())
|
||||
if "earth" in words:
|
||||
vector = self.get_vector_near(0.9)
|
||||
elif {"planet", "world", "globe", "sphere"}.intersection(words):
|
||||
vector = self.get_vector_near(0.8)
|
||||
else:
|
||||
vector = self.get_vector_near(0.1)
|
||||
return vector
|
||||
|
||||
|
||||
def _result_ids(docs: Iterable[Document]) -> List[Optional[str]]:
|
||||
return [doc.id for doc in docs]
|
||||
|
||||
|
||||
def test_mmr_traversal() -> None:
|
||||
@pytest.fixture
|
||||
def graph_vector_store_docs() -> list[Document]:
|
||||
"""
|
||||
Test end to end construction and MMR search.
|
||||
This is a set of Documents to pre-populate a graph vector store,
|
||||
with entries placed in a certain way.
|
||||
|
||||
Space of the entries (under Euclidean similarity):
|
||||
|
||||
A0 (*)
|
||||
.... AL AR <....
|
||||
: | :
|
||||
: | ^ :
|
||||
v | . v
|
||||
| :
|
||||
TR | : BL
|
||||
T0 --------------x-------------- B0
|
||||
TL | : BR
|
||||
| :
|
||||
| .
|
||||
| .
|
||||
|
|
||||
FL FR
|
||||
F0
|
||||
|
||||
the query point is meant to be at (*).
|
||||
the A are bidirectionally with B
|
||||
the A are outgoing to T
|
||||
the A are incoming from F
|
||||
The links are like: L with L, 0 with 0 and R with R.
|
||||
"""
|
||||
|
||||
docs_a = [
|
||||
Document(id="AL", page_content="[-1, 9]", metadata={"label": "AL"}),
|
||||
Document(id="A0", page_content="[0, 10]", metadata={"label": "A0"}),
|
||||
Document(id="AR", page_content="[1, 9]", metadata={"label": "AR"}),
|
||||
]
|
||||
docs_b = [
|
||||
Document(id="BL", page_content="[9, 1]", metadata={"label": "BL"}),
|
||||
Document(id="B0", page_content="[10, 0]", metadata={"label": "B0"}),
|
||||
Document(id="BL", page_content="[9, -1]", metadata={"label": "BR"}),
|
||||
]
|
||||
docs_f = [
|
||||
Document(id="FL", page_content="[1, -9]", metadata={"label": "FL"}),
|
||||
Document(id="F0", page_content="[0, -10]", metadata={"label": "F0"}),
|
||||
Document(id="FR", page_content="[-1, -9]", metadata={"label": "FR"}),
|
||||
]
|
||||
docs_t = [
|
||||
Document(id="TL", page_content="[-9, -1]", metadata={"label": "TL"}),
|
||||
Document(id="T0", page_content="[-10, 0]", metadata={"label": "T0"}),
|
||||
Document(id="TR", page_content="[-9, 1]", metadata={"label": "TR"}),
|
||||
]
|
||||
for doc_a, suffix in zip(docs_a, ["l", "0", "r"]):
|
||||
add_links(doc_a, Link.bidir(kind="ab_example", tag=f"tag_{suffix}"))
|
||||
add_links(doc_a, Link.outgoing(kind="at_example", tag=f"tag_{suffix}"))
|
||||
add_links(doc_a, Link.incoming(kind="af_example", tag=f"tag_{suffix}"))
|
||||
for doc_b, suffix in zip(docs_b, ["l", "0", "r"]):
|
||||
add_links(doc_b, Link.bidir(kind="ab_example", tag=f"tag_{suffix}"))
|
||||
for doc_t, suffix in zip(docs_t, ["l", "0", "r"]):
|
||||
add_links(doc_t, Link.incoming(kind="at_example", tag=f"tag_{suffix}"))
|
||||
for doc_f, suffix in zip(docs_f, ["l", "0", "r"]):
|
||||
add_links(doc_f, Link.outgoing(kind="af_example", tag=f"tag_{suffix}"))
|
||||
return docs_a + docs_b + docs_f + docs_t
|
||||
|
||||
|
||||
class CassandraSession:
|
||||
table_name: str
|
||||
session: Any
|
||||
|
||||
def __init__(self, table_name: str, session: Any):
|
||||
self.table_name = table_name
|
||||
self.session = session
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_cassandra_session(
|
||||
table_name: str, drop: bool = True
|
||||
) -> Generator[CassandraSession, None, None]:
|
||||
"""Initialize the Cassandra cluster and session"""
|
||||
from cassandra.cluster import Cluster
|
||||
|
||||
if "CASSANDRA_CONTACT_POINTS" in os.environ:
|
||||
contact_points = [
|
||||
cp.strip()
|
||||
for cp in os.environ["CASSANDRA_CONTACT_POINTS"].split(",")
|
||||
if cp.strip()
|
||||
]
|
||||
else:
|
||||
contact_points = None
|
||||
|
||||
cluster = Cluster(contact_points)
|
||||
session = cluster.connect()
|
||||
|
||||
try:
|
||||
session.execute(
|
||||
(
|
||||
f"CREATE KEYSPACE IF NOT EXISTS {TEST_KEYSPACE}"
|
||||
" WITH replication = "
|
||||
"{'class': 'SimpleStrategy', 'replication_factor': 1}"
|
||||
)
|
||||
)
|
||||
if drop:
|
||||
session.execute(f"DROP TABLE IF EXISTS {TEST_KEYSPACE}.{table_name}")
|
||||
|
||||
# Yield the session for usage
|
||||
yield CassandraSession(table_name=table_name, session=session)
|
||||
finally:
|
||||
# Ensure proper shutdown/cleanup of resources
|
||||
session.shutdown()
|
||||
cluster.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def graph_vector_store_angular(
|
||||
table_name: str = "graph_test_table",
|
||||
) -> Generator[CassandraGraphVectorStore, None, None]:
|
||||
with get_cassandra_session(table_name=table_name) as session:
|
||||
yield CassandraGraphVectorStore(
|
||||
embedding=AngularTwoDimensionalEmbeddings(),
|
||||
session=session.session,
|
||||
keyspace=TEST_KEYSPACE,
|
||||
table_name=session.table_name,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def graph_vector_store_earth(
|
||||
table_name: str = "graph_test_table",
|
||||
) -> Generator[CassandraGraphVectorStore, None, None]:
|
||||
with get_cassandra_session(table_name=table_name) as session:
|
||||
yield CassandraGraphVectorStore(
|
||||
embedding=EarthEmbeddings(),
|
||||
session=session.session,
|
||||
keyspace=TEST_KEYSPACE,
|
||||
table_name=session.table_name,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def graph_vector_store_fake(
|
||||
table_name: str = "graph_test_table",
|
||||
) -> Generator[CassandraGraphVectorStore, None, None]:
|
||||
with get_cassandra_session(table_name=table_name) as session:
|
||||
yield CassandraGraphVectorStore(
|
||||
embedding=FakeEmbeddings(),
|
||||
session=session.session,
|
||||
keyspace=TEST_KEYSPACE,
|
||||
table_name=session.table_name,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def graph_vector_store_d2(
|
||||
embedding_d2: Embeddings,
|
||||
table_name: str = "graph_test_table",
|
||||
) -> Generator[CassandraGraphVectorStore, None, None]:
|
||||
with get_cassandra_session(table_name=table_name) as session:
|
||||
yield CassandraGraphVectorStore(
|
||||
embedding=embedding_d2,
|
||||
session=session.session,
|
||||
keyspace=TEST_KEYSPACE,
|
||||
table_name=session.table_name,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def populated_graph_vector_store_d2(
|
||||
graph_vector_store_d2: CassandraGraphVectorStore,
|
||||
graph_vector_store_docs: list[Document],
|
||||
) -> Generator[CassandraGraphVectorStore, None, None]:
|
||||
graph_vector_store_d2.add_documents(graph_vector_store_docs)
|
||||
yield graph_vector_store_d2
|
||||
|
||||
|
||||
def test_mmr_traversal(graph_vector_store_angular: CassandraGraphVectorStore) -> None:
    """Test end to end construction and MMR search.
    The embedding function used here ensures `texts` become
    the following vectors on a circle (numbered v0 through v3):

@ -128,140 +267,128 @@ def test_mmr_traversal() -> None:
    Both v2 and v3 are reachable via edges from v0, so once it is
    selected, those are both considered.
    """
    store = _get_graph_store(AngularTwoDimensionalEmbeddings)

    v0 = Document(
    v0 = Node(
        id="v0",
        page_content="-0.124",
        metadata={
            METADATA_LINKS_KEY: [
                Link.outgoing(kind="explicit", tag="link"),
            ],
        },
        text="-0.124",
        links=[
            Link.outgoing(kind="explicit", tag="link"),
        ],
    )
    v1 = Document(
    v1 = Node(
        id="v1",
        page_content="+0.127",
        text="+0.127",
    )
    v2 = Document(
    v2 = Node(
        id="v2",
        page_content="+0.25",
        metadata={
            METADATA_LINKS_KEY: [
                Link.incoming(kind="explicit", tag="link"),
            ],
        },
        text="+0.25",
        links=[
            Link.incoming(kind="explicit", tag="link"),
        ],
    )
    v3 = Document(
    v3 = Node(
        id="v3",
        page_content="+1.0",
        metadata={
            METADATA_LINKS_KEY: [
                Link.incoming(kind="explicit", tag="link"),
            ],
        },
        text="+1.0",
        links=[
            Link.incoming(kind="explicit", tag="link"),
        ],
    )
    store.add_documents([v0, v1, v2, v3])

    results = store.mmr_traversal_search("0.0", k=2, fetch_k=2)
    g_store = graph_vector_store_angular
    g_store.add_nodes([v0, v1, v2, v3])

    results = g_store.mmr_traversal_search("0.0", k=2, fetch_k=2)
    assert _result_ids(results) == ["v0", "v2"]

    # With max depth 0, no edges are traversed, so this doesn't reach v2 or v3.
    # So it ends up picking "v1" even though it's similar to "v0".
    results = store.mmr_traversal_search("0.0", k=2, fetch_k=2, depth=0)
    results = g_store.mmr_traversal_search("0.0", k=2, fetch_k=2, depth=0)
    assert _result_ids(results) == ["v0", "v1"]

    # With max depth 0 but a higher `fetch_k`, we encounter v2.
    results = store.mmr_traversal_search("0.0", k=2, fetch_k=3, depth=0)
    results = g_store.mmr_traversal_search("0.0", k=2, fetch_k=3, depth=0)
    assert _result_ids(results) == ["v0", "v2"]

    # v0 score is 0.46, v2 score is 0.16, so v2 won't be chosen.
    results = store.mmr_traversal_search("0.0", k=2, score_threshold=0.2)
    results = g_store.mmr_traversal_search("0.0", k=2, score_threshold=0.2)
    assert _result_ids(results) == ["v0"]

    # With k=4 we should get all of the documents.
    results = store.mmr_traversal_search("0.0", k=4)
    results = g_store.mmr_traversal_search("0.0", k=4)
    assert _result_ids(results) == ["v0", "v2", "v1", "v3"]

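`AngularTwoDimensionalEmbeddings` comes from the shared test helpers and is not shown in this diff. A minimal sketch of the idea, assuming each text is parsed as an angle in units of pi and mapped onto the unit circle (the fallback vector for non-numeric texts is an assumption):

import math

from langchain_core.embeddings import Embeddings


class AngularTwoDimensionalEmbeddings(Embeddings):
    """Map a text holding an angle (in units of pi) to a 2-D unit vector."""

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        return [self.embed_query(text) for text in texts]

    def embed_query(self, text: str) -> list[float]:
        try:
            angle = float(text)
        except ValueError:
            # Fallback for non-numeric texts; the exact vector is an assumption.
            return [0.0, 0.0]
        return [math.cos(angle * math.pi), math.sin(angle * math.pi)]

This is why node texts like "-0.124" and "+0.25" above double as positions on the circle from the docstring's drawing.
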
def test_write_retrieve_keywords() -> None:
    from langchain_openai import OpenAIEmbeddings

    greetings = Document(
def test_write_retrieve_keywords(
    graph_vector_store_earth: CassandraGraphVectorStore,
) -> None:
    greetings = Node(
        id="greetings",
        page_content="Typical Greetings",
        metadata={
            METADATA_LINKS_KEY: [
                Link.incoming(kind="parent", tag="parent"),
            ],
        },
        text="Typical Greetings",
        links=[
            Link.incoming(kind="parent", tag="parent"),
        ],
    )
    doc1 = Document(

    node1 = Node(
        id="doc1",
        page_content="Hello World",
        metadata={
            METADATA_LINKS_KEY: [
                Link.outgoing(kind="parent", tag="parent"),
                Link.bidir(kind="kw", tag="greeting"),
                Link.bidir(kind="kw", tag="world"),
            ],
        },
        text="Hello World",
        links=[
            Link.outgoing(kind="parent", tag="parent"),
            Link.bidir(kind="kw", tag="greeting"),
            Link.bidir(kind="kw", tag="world"),
        ],
    )
    doc2 = Document(

    node2 = Node(
        id="doc2",
        page_content="Hello Earth",
        metadata={
            METADATA_LINKS_KEY: [
                Link.outgoing(kind="parent", tag="parent"),
                Link.bidir(kind="kw", tag="greeting"),
                Link.bidir(kind="kw", tag="earth"),
            ],
        },
        text="Hello Earth",
        links=[
            Link.outgoing(kind="parent", tag="parent"),
            Link.bidir(kind="kw", tag="greeting"),
            Link.bidir(kind="kw", tag="earth"),
        ],
    )
    store = _get_graph_store(OpenAIEmbeddings, [greetings, doc1, doc2])

    g_store = graph_vector_store_earth
    g_store.add_nodes(nodes=[greetings, node1, node2])

    # Doc2 is more similar, but World and Earth are similar enough that doc1 also
    # shows up.
    results: Iterable[Document] = store.similarity_search("Earth", k=2)
    results: Iterable[Document] = g_store.similarity_search("Earth", k=2)
    assert _result_ids(results) == ["doc2", "doc1"]

    results = store.similarity_search("Earth", k=1)
    results = g_store.similarity_search("Earth", k=1)
    assert _result_ids(results) == ["doc2"]

    results = store.traversal_search("Earth", k=2, depth=0)
    results = g_store.traversal_search("Earth", k=2, depth=0)
    assert _result_ids(results) == ["doc2", "doc1"]

    results = store.traversal_search("Earth", k=2, depth=1)
    results = g_store.traversal_search("Earth", k=2, depth=1)
    assert _result_ids(results) == ["doc2", "doc1", "greetings"]

    # K=1 only pulls in doc2 (Hello Earth)
    results = store.traversal_search("Earth", k=1, depth=0)
    results = g_store.traversal_search("Earth", k=1, depth=0)
    assert _result_ids(results) == ["doc2"]

    # K=1 only pulls in doc2 (Hello Earth). Depth=1 traverses to parent and via
    # keyword edge.
    results = store.traversal_search("Earth", k=1, depth=1)
    results = g_store.traversal_search("Earth", k=1, depth=1)
    assert set(_result_ids(results)) == {"doc2", "doc1", "greetings"}

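A note on the `Link` helpers used above: links pair up by kind and tag. An outgoing link on one node matches an incoming link with the same kind and tag on another node, and `bidir` links match in both directions, so traversal can move either way. A small illustrative sketch:

from langchain_community.graph_vectorstores.links import Link

# "doc1" declares an edge toward its parent...
child_side = Link.outgoing(kind="parent", tag="parent")
# ...and "greetings" accepts it, because kind and tag match.
parent_side = Link.incoming(kind="parent", tag="parent")

# Keyword links are bidirectional: all nodes tagged "greeting" become
# mutually reachable during traversal.
keyword = Link.bidir(kind="kw", tag="greeting")

This is what makes depth=1 pull in "greetings" above: traversal follows the matched parent edge out of the similarity hits.
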
def test_metadata() -> None:
    store = _get_graph_store(FakeEmbeddings)
    store.add_documents(
        [
            Document(
                id="a",
                page_content="A",
                metadata={
                    METADATA_LINKS_KEY: [
                        Link.incoming(kind="hyperlink", tag="http://a"),
                        Link.bidir(kind="other", tag="foo"),
                    ],
                    "other": "some other field",
                },
            )
        ]
def test_metadata(graph_vector_store_fake: CassandraGraphVectorStore) -> None:
    doc_a = Node(
        id="a",
        text="A",
        metadata={"other": "some other field"},
        links=[
            Link.incoming(kind="hyperlink", tag="http://a"),
            Link.bidir(kind="other", tag="foo"),
        ],
    )
    results = store.similarity_search("A")

    g_store = graph_vector_store_fake
    g_store.add_nodes([doc_a])
    results = g_store.similarity_search("A")
    assert len(results) == 1
    assert results[0].id == "a"
    metadata = results[0].metadata
@ -270,3 +397,274 @@ def test_metadata() -> None:
        Link.incoming(kind="hyperlink", tag="http://a"),
        Link.bidir(kind="other", tag="foo"),
    }

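The `populated_graph_vector_store_d2` tests below also depend on a `graph_vector_store_docs` fixture that this excerpt does not show. From the assertions one can infer its shape: two-dimensional page contents such as "[-10, 0]", a `label` metadata field (`A0`, `AR`, `B0`, `BR`, `T0`, `TR`, `FL`, ...), and links of kinds like `at_example` and `af_example` that wire the documents into a traversable graph. A hypothetical fragment, purely to make the expectations concrete:

from langchain_core.documents import Document

from langchain_community.graph_vectorstores.links import METADATA_LINKS_KEY, Link

# Hypothetical reconstruction of two entries; the real fixture defines many more.
doc_t0 = Document(
    id="T0",
    page_content="[-10, 0]",
    metadata={
        "label": "T0",
        METADATA_LINKS_KEY: [Link.incoming(kind="at_example", tag="tag_0")],
    },
)
doc_fl = Document(
    id="FL",
    page_content="[1, -9]",
    metadata={
        "label": "FL",
        METADATA_LINKS_KEY: [Link.outgoing(kind="af_example", tag="tag_l")],
    },
)
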
class TestCassandraGraphVectorStore:
    def test_gvs_similarity_search_sync(
        self,
        populated_graph_vector_store_d2: CassandraGraphVectorStore,
    ) -> None:
        """Simple (non-graph) similarity search on a graph vector store."""
        g_store = populated_graph_vector_store_d2
        ss_response = g_store.similarity_search(query="[2, 10]", k=2)
        ss_labels = [doc.metadata["label"] for doc in ss_response]
        assert ss_labels == ["AR", "A0"]
        ss_by_v_response = g_store.similarity_search_by_vector(embedding=[2, 10], k=2)
        ss_by_v_labels = [doc.metadata["label"] for doc in ss_by_v_response]
        assert ss_by_v_labels == ["AR", "A0"]

    async def test_gvs_similarity_search_async(
        self,
        populated_graph_vector_store_d2: CassandraGraphVectorStore,
    ) -> None:
        """Simple (non-graph) similarity search on a graph vector store."""
        g_store = populated_graph_vector_store_d2
        ss_response = await g_store.asimilarity_search(query="[2, 10]", k=2)
        ss_labels = [doc.metadata["label"] for doc in ss_response]
        assert ss_labels == ["AR", "A0"]
        ss_by_v_response = await g_store.asimilarity_search_by_vector(
            embedding=[2, 10], k=2
        )
        ss_by_v_labels = [doc.metadata["label"] for doc in ss_by_v_response]
        assert ss_by_v_labels == ["AR", "A0"]

    def test_gvs_traversal_search_sync(
        self,
        populated_graph_vector_store_d2: CassandraGraphVectorStore,
    ) -> None:
        """Graph traversal search on a graph vector store."""
        g_store = populated_graph_vector_store_d2
        ts_response = g_store.traversal_search(query="[2, 10]", k=2, depth=2)
        # this is a set, as some of the internals of traversal search are
        # set-driven, so ordering is not deterministic:
        ts_labels = {doc.metadata["label"] for doc in ts_response}
        assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"}

    async def test_gvs_traversal_search_async(
        self,
        populated_graph_vector_store_d2: CassandraGraphVectorStore,
    ) -> None:
        """Graph traversal search on a graph vector store."""
        g_store = populated_graph_vector_store_d2
        ts_labels = set()
        async for doc in g_store.atraversal_search(query="[2, 10]", k=2, depth=2):
            ts_labels.add(doc.metadata["label"])
        # this is a set, as some of the internals of traversal search are
        # set-driven, so ordering is not deterministic:
        assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"}

    def test_gvs_mmr_traversal_search_sync(
        self,
        populated_graph_vector_store_d2: CassandraGraphVectorStore,
    ) -> None:
        """MMR Graph traversal search on a graph vector store."""
        g_store = populated_graph_vector_store_d2
        mt_response = g_store.mmr_traversal_search(
            query="[2, 10]",
            k=2,
            depth=2,
            fetch_k=1,
            adjacent_k=2,
            lambda_mult=0.1,
        )
        # TODO: can this rightfully be a list (or must it be a set)?
        mt_labels = {doc.metadata["label"] for doc in mt_response}
        assert mt_labels == {"AR", "BR"}

    async def test_gvs_mmr_traversal_search_async(
        self,
        populated_graph_vector_store_d2: CassandraGraphVectorStore,
    ) -> None:
        """MMR Graph traversal search on a graph vector store."""
        g_store = populated_graph_vector_store_d2
        mt_labels = set()
        async for doc in g_store.ammr_traversal_search(
            query="[2, 10]",
            k=2,
            depth=2,
            fetch_k=1,
            adjacent_k=2,
            lambda_mult=0.1,
        ):
            mt_labels.add(doc.metadata["label"])
        # TODO: can this rightfully be a list (or must it be a set)?
        assert mt_labels == {"AR", "BR"}

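On the knobs used above: `fetch_k` is the number of initial candidates retrieved by plain vector similarity, `adjacent_k` is how many neighbors are fetched per selected node while traversing, `depth` bounds how many edges may be followed, and `lambda_mult` trades relevance against diversity (1 favors pure similarity, 0 favors maximum diversity). With `fetch_k=1` and `lambda_mult=0.1`, the search starts from a single seed and strongly prefers diverse nodes reached through the graph, which is why the two selected labels span two branches. A hedged usage sketch with less aggressive, purely illustrative values:

results = g_store.mmr_traversal_search(
    query="[2, 10]",
    k=4,  # documents to return
    fetch_k=10,  # initial similarity-based candidates
    adjacent_k=10,  # neighbors pulled in per selected node
    depth=2,  # maximum edge-distance from the initial candidates
    lambda_mult=0.5,  # balance relevance against diversity
)
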
    def test_gvs_metadata_search_sync(
        self,
        populated_graph_vector_store_d2: CassandraGraphVectorStore,
    ) -> None:
        """Metadata search on a graph vector store."""
        g_store = populated_graph_vector_store_d2
        mt_response = g_store.metadata_search(
            filter={"label": "T0"},
            n=2,
        )
        doc: Document = next(iter(mt_response))
        assert doc.page_content == "[-10, 0]"
        links = doc.metadata["links"]
        assert len(links) == 1
        link: Link = links.pop()
        assert isinstance(link, Link)
        assert link.direction == "in"
        assert link.kind == "at_example"
        assert link.tag == "tag_0"

    async def test_gvs_metadata_search_async(
        self,
        populated_graph_vector_store_d2: CassandraGraphVectorStore,
    ) -> None:
        """Metadata search on a graph vector store."""
        g_store = populated_graph_vector_store_d2
        mt_response = await g_store.ametadata_search(
            filter={"label": "T0"},
            n=2,
        )
        doc: Document = next(iter(mt_response))
        assert doc.page_content == "[-10, 0]"
        links: set[Link] = doc.metadata["links"]
        assert len(links) == 1
        link: Link = links.pop()
        assert isinstance(link, Link)
        assert link.direction == "in"
        assert link.kind == "at_example"
        assert link.tag == "tag_0"

    def test_gvs_get_by_document_id_sync(
        self,
        populated_graph_vector_store_d2: CassandraGraphVectorStore,
    ) -> None:
        """Get by document_id on a graph vector store."""
        g_store = populated_graph_vector_store_d2
        doc = g_store.get_by_document_id(document_id="FL")
        assert doc is not None
        assert doc.page_content == "[1, -9]"
        links = doc.metadata["links"]
        assert len(links) == 1
        link: Link = links.pop()
        assert isinstance(link, Link)
        assert link.direction == "out"
        assert link.kind == "af_example"
        assert link.tag == "tag_l"

        invalid_doc = g_store.get_by_document_id(document_id="invalid")
        assert invalid_doc is None

    async def test_gvs_get_by_document_id_async(
        self,
        populated_graph_vector_store_d2: CassandraGraphVectorStore,
    ) -> None:
        """Get by document_id on a graph vector store."""
        g_store = populated_graph_vector_store_d2
        doc = await g_store.aget_by_document_id(document_id="FL")
        assert doc is not None
        assert doc.page_content == "[1, -9]"
        links = doc.metadata["links"]
        assert len(links) == 1
        link: Link = links.pop()
        assert isinstance(link, Link)
        assert link.direction == "out"
        assert link.kind == "af_example"
        assert link.tag == "tag_l"

        invalid_doc = await g_store.aget_by_document_id(document_id="invalid")
        assert invalid_doc is None

    def test_gvs_from_texts(
        self,
        graph_vector_store_d2: CassandraGraphVectorStore,
    ) -> None:
        g_store = graph_vector_store_d2
        g_store.add_texts(
            texts=["[1, 2]"],
            metadatas=[{"md": 1}],
            ids=["x_id"],
        )

        hits = g_store.similarity_search("[2, 1]", k=2)
        assert len(hits) == 1
        assert hits[0].page_content == "[1, 2]"
        assert hits[0].id == "x_id"
        # there may be more metadata keys due to the graph structure
        assert hits[0].metadata["md"] == "1.0"

    def test_gvs_from_documents_containing_ids(
        self,
        graph_vector_store_d2: CassandraGraphVectorStore,
    ) -> None:
        the_document = Document(
            page_content="[1, 2]",
            metadata={"md": 1},
            id="x_id",
        )
        g_store = graph_vector_store_d2
        g_store.add_documents([the_document])
        hits = g_store.similarity_search("[2, 1]", k=2)
        assert len(hits) == 1
        assert hits[0].page_content == "[1, 2]"
        assert hits[0].id == "x_id"
        # there may be more metadata keys due to the graph structure
        assert hits[0].metadata["md"] == "1.0"

    def test_gvs_add_nodes_sync(
        self,
        *,
        graph_vector_store_d2: CassandraGraphVectorStore,
    ) -> None:
        links0 = [
            Link(kind="kA", direction="out", tag="tA"),
            Link(kind="kB", direction="bidir", tag="tB"),
        ]
        links1 = [
            Link(kind="kC", direction="in", tag="tC"),
        ]
        nodes = [
            Node(id="id0", text="[1, 0]", metadata={"m": 0}, links=links0),
            Node(text="[-1, 0]", metadata={"m": 1}, links=links1),
        ]
        graph_vector_store_d2.add_nodes(nodes)
        hits = graph_vector_store_d2.similarity_search_by_vector([0.9, 0.1])
        assert len(hits) == 2
        assert hits[0].id == "id0"
        assert hits[0].page_content == "[1, 0]"
        md0 = hits[0].metadata
        assert md0["m"] == "0.0"
        assert any(isinstance(v, set) for k, v in md0.items() if k != "m")

        assert hits[1].id != "id0"
        assert hits[1].page_content == "[-1, 0]"
        md1 = hits[1].metadata
        assert md1["m"] == "1.0"
        assert any(isinstance(v, set) for k, v in md1.items() if k != "m")

    async def test_gvs_add_nodes_async(
        self,
        *,
        graph_vector_store_d2: CassandraGraphVectorStore,
    ) -> None:
        links0 = [
            Link(kind="kA", direction="out", tag="tA"),
            Link(kind="kB", direction="bidir", tag="tB"),
        ]
        links1 = [
            Link(kind="kC", direction="in", tag="tC"),
        ]
        nodes = [
            Node(id="id0", text="[1, 0]", metadata={"m": 0}, links=links0),
            Node(text="[-1, 0]", metadata={"m": 1}, links=links1),
        ]
        async for _ in graph_vector_store_d2.aadd_nodes(nodes):
            pass

        hits = await graph_vector_store_d2.asimilarity_search_by_vector([0.9, 0.1])
        assert len(hits) == 2
        assert hits[0].id == "id0"
        assert hits[0].page_content == "[1, 0]"
        md0 = hits[0].metadata
        assert md0["m"] == "0.0"
        assert any(isinstance(v, set) for k, v in md0.items() if k != "m")
        assert hits[1].id != "id0"
        assert hits[1].page_content == "[-1, 0]"
        md1 = hits[1].metadata
        assert md1["m"] == "1.0"
        assert any(isinstance(v, set) for k, v in md1.items() if k != "m")

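Two behaviors in the assertions above come from the underlying Cassandra table rather than from the graph layer: numeric metadata values round-trip as strings ("0.0", "1.0"), and a node's links are surfaced back through a set-valued metadata entry. A quick inspection sketch reusing the fixture above:

hits = graph_vector_store_d2.similarity_search_by_vector([0.9, 0.1])
for hit in hits:
    # Illustrative output: metadata values come back as strings, and the
    # node's links appear under a set-valued metadata entry.
    print(hit.id, hit.metadata)
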
@ -0,0 +1,269 @@
"""Test upgrading to the Apache Cassandra graph vector store class
`CassandraGraphVectorStore` from an existing table used
by the Cassandra vector store class `Cassandra`.
"""

from __future__ import annotations

import json
import os
from contextlib import contextmanager
from typing import Any, Generator, Iterable, Optional, Tuple, Union

import pytest
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings

from langchain_community.graph_vectorstores import CassandraGraphVectorStore
from langchain_community.utilities.cassandra import SetupMode
from langchain_community.vectorstores import Cassandra

TEST_KEYSPACE = "graph_test_keyspace"

TABLE_NAME_ALLOW_INDEXING = "allow_graph_table"
TABLE_NAME_DEFAULT = "default_graph_table"
TABLE_NAME_DENY_INDEXING = "deny_graph_table"


class ParserEmbeddings(Embeddings):
    """Parse input texts: if they are json for a List[float], fine.
    Otherwise, return all zeros and call it a day.
    """

    def __init__(self, dimension: int) -> None:
        self.dimension = dimension

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        return [self.embed_query(txt) for txt in texts]

    async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
        return self.embed_documents(texts)

    def embed_query(self, text: str) -> list[float]:
        try:
            vals = json.loads(text)
        except json.JSONDecodeError:
            return [0.0] * self.dimension
        else:
            assert len(vals) == self.dimension
            return vals

    async def aembed_query(self, text: str) -> list[float]:
        return self.embed_query(text)

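Because the "embedding" here is just a parser, expected distances in these tests are exact rather than approximate. For instance:

emb = ParserEmbeddings(dimension=2)
assert emb.embed_query("[-1, 9]") == [-1.0, 9.0]  # valid JSON list -> itself
assert emb.embed_query("not json") == [0.0, 0.0]  # anything else -> zeros
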
@pytest.fixture
def embedding_d2() -> Embeddings:
    return ParserEmbeddings(dimension=2)


class CassandraSession:
    table_name: str
    session: Any

    def __init__(self, table_name: str, session: Any):
        self.table_name = table_name
        self.session = session


@contextmanager
def get_cassandra_session(
    table_name: str, drop: bool = True
) -> Generator[CassandraSession, None, None]:
    """Initialize the Cassandra cluster and session."""
    from cassandra.cluster import Cluster

    if "CASSANDRA_CONTACT_POINTS" in os.environ:
        contact_points = [
            cp.strip()
            for cp in os.environ["CASSANDRA_CONTACT_POINTS"].split(",")
            if cp.strip()
        ]
    else:
        contact_points = None

    cluster = Cluster(contact_points)
    session = cluster.connect()

    try:
        session.execute(
            (
                f"CREATE KEYSPACE IF NOT EXISTS {TEST_KEYSPACE}"
                " WITH replication = "
                "{'class': 'SimpleStrategy', 'replication_factor': 1}"
            )
        )
        if drop:
            session.execute(f"DROP TABLE IF EXISTS {TEST_KEYSPACE}.{table_name}")

        # Yield the session for usage
        yield CassandraSession(table_name=table_name, session=session)
    finally:
        # Ensure proper shutdown/cleanup of resources
        session.shutdown()
        cluster.shutdown()

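A hedged usage sketch of the context manager above; it hands back both the live driver session and the table name, so callers can build stores against a freshly dropped table:

with get_cassandra_session(table_name="default_graph_table") as cs:
    # cs.session is a live cassandra-driver Session scoped to this block;
    # the target table was dropped on entry (drop=True is the default).
    row = cs.session.execute("SELECT release_version FROM system.local").one()
    print(cs.table_name, row)
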
@contextmanager
def vector_store(
    embedding: Embeddings,
    table_name: str,
    setup_mode: SetupMode,
    metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
    drop: bool = True,
) -> Generator[Cassandra, None, None]:
    with get_cassandra_session(table_name=table_name, drop=drop) as session:
        yield Cassandra(
            table_name=session.table_name,
            keyspace=TEST_KEYSPACE,
            session=session.session,
            embedding=embedding,
            setup_mode=setup_mode,
            metadata_indexing=metadata_indexing,
        )


@contextmanager
def graph_vector_store(
    embedding: Embeddings,
    table_name: str,
    setup_mode: SetupMode,
    metadata_deny_list: Optional[list[str]] = None,
    drop: bool = True,
) -> Generator[CassandraGraphVectorStore, None, None]:
    with get_cassandra_session(table_name=table_name, drop=drop) as session:
        yield CassandraGraphVectorStore(
            table_name=session.table_name,
            keyspace=TEST_KEYSPACE,
            session=session.session,
            embedding=embedding,
            setup_mode=setup_mode,
            metadata_deny_list=metadata_deny_list,
        )


def _vs_indexing_policy(table_name: str) -> Union[Tuple[str, Iterable[str]], str]:
    if table_name == TABLE_NAME_ALLOW_INDEXING:
        return ("allowlist", ["test"])
    if table_name == TABLE_NAME_DEFAULT:
        return "all"
    if table_name == TABLE_NAME_DENY_INDEXING:
        return ("denylist", ["test"])
    msg = f"Unknown table_name: {table_name} in _vs_indexing_policy()"
    raise ValueError(msg)

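The parametrized cases below lean on the correspondence between the two stores' indexing options: `Cassandra` takes a `metadata_indexing` policy, while `CassandraGraphVectorStore` expresses the deny-list case directly. The equivalence the tests assume, side by side:

# On the plain vector store: do not index the "test" metadata field.
metadata_indexing = ("denylist", ["test"])

# On the graph vector store opening the same table: the equivalent option.
metadata_deny_list = ["test"]
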
class TestUpgradeToGraphVectorStore:
    @pytest.mark.parametrize(
        ("table_name", "gvs_setup_mode", "gvs_metadata_deny_list"),
        [
            (TABLE_NAME_DEFAULT, SetupMode.SYNC, None),
            (TABLE_NAME_DENY_INDEXING, SetupMode.SYNC, ["test"]),
            (TABLE_NAME_DEFAULT, SetupMode.OFF, None),
            (TABLE_NAME_DENY_INDEXING, SetupMode.OFF, ["test"]),
            # for this one, even though the passed policy doesn't
            # match the policy used to create the collection,
            # there is no error since the SetupMode is OFF and
            # no attempt is made to re-create the collection.
            (TABLE_NAME_DENY_INDEXING, SetupMode.OFF, None),
        ],
        ids=[
            "default_upgrade_no_policy_sync",
            "deny_list_upgrade_same_policy_sync",
            "default_upgrade_no_policy_off",
            "deny_list_upgrade_same_policy_off",
            "deny_list_upgrade_change_policy_off",
        ],
    )
    def test_upgrade_to_gvs_success_sync(
        self,
        *,
        embedding_d2: Embeddings,
        gvs_setup_mode: SetupMode,
        table_name: str,
        gvs_metadata_deny_list: list[str],
    ) -> None:
        doc_id = "AL"
        doc_al = Document(id=doc_id, page_content="[-1, 9]", metadata={"label": "AL"})

        # Create vector store using SetupMode.SYNC
        with vector_store(
            embedding=embedding_d2,
            table_name=table_name,
            setup_mode=SetupMode.SYNC,
            metadata_indexing=_vs_indexing_policy(table_name=table_name),
            drop=True,
        ) as v_store:
            # load a document into the vector store
            v_store.add_documents([doc_al])

            # get the document from the vector store
            v_doc = v_store.get_by_document_id(document_id=doc_id)
            assert v_doc is not None
            assert v_doc.page_content == doc_al.page_content

        # Create a GRAPH vector store using the existing collection from above
        # with setup_mode=gvs_setup_mode and metadata_deny_list=gvs_metadata_deny_list
        with graph_vector_store(
            embedding=embedding_d2,
            table_name=table_name,
            setup_mode=gvs_setup_mode,
            metadata_deny_list=gvs_metadata_deny_list,
            drop=False,
        ) as gv_store:
            # get the document from the GRAPH vector store
            gv_doc = gv_store.get_by_document_id(document_id=doc_id)
            assert gv_doc is not None
            assert gv_doc.page_content == doc_al.page_content

    @pytest.mark.parametrize(
        ("table_name", "gvs_setup_mode", "gvs_metadata_deny_list"),
        [
            (TABLE_NAME_DEFAULT, SetupMode.ASYNC, None),
            (TABLE_NAME_DENY_INDEXING, SetupMode.ASYNC, ["test"]),
        ],
        ids=[
            "default_upgrade_no_policy_async",
            "deny_list_upgrade_same_policy_async",
        ],
    )
    async def test_upgrade_to_gvs_success_async(
        self,
        *,
        embedding_d2: Embeddings,
        gvs_setup_mode: SetupMode,
        table_name: str,
        gvs_metadata_deny_list: list[str],
    ) -> None:
        doc_id = "AL"
        doc_al = Document(id=doc_id, page_content="[-1, 9]", metadata={"label": "AL"})

        # Create vector store using SetupMode.ASYNC
        with vector_store(
            embedding=embedding_d2,
            table_name=table_name,
            setup_mode=SetupMode.ASYNC,
            metadata_indexing=_vs_indexing_policy(table_name=table_name),
            drop=True,
        ) as v_store:
            # load a document into the vector store
            await v_store.aadd_documents([doc_al])

            # get the document from the vector store
            v_doc = await v_store.aget_by_document_id(document_id=doc_id)
            assert v_doc is not None
            assert v_doc.page_content == doc_al.page_content

        # Create a GRAPH vector store using the existing collection from above
        # with setup_mode=gvs_setup_mode and metadata_deny_list=gvs_metadata_deny_list
        with graph_vector_store(
            embedding=embedding_d2,
            table_name=table_name,
            setup_mode=gvs_setup_mode,
            metadata_deny_list=gvs_metadata_deny_list,
            drop=False,
        ) as gv_store:
            # get the document from the GRAPH vector store
            gv_doc = await gv_store.aget_by_document_id(document_id=doc_id)
            assert gv_doc is not None
            assert gv_doc.page_content == doc_al.page_content

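Distilled, the upgrade path these tests exercise looks like this (a sketch reusing the helpers above; `drop=False` on the second step is what preserves the existing rows):

doc = Document(id="AL", page_content="[-1, 9]", metadata={"label": "AL"})

# 1. An ordinary vector store writes to a table...
with vector_store(
    embedding=ParserEmbeddings(dimension=2),
    table_name=TABLE_NAME_DEFAULT,
    setup_mode=SetupMode.SYNC,
) as v_store:
    v_store.add_documents([doc])

# 2. ...and a graph vector store later opens the same table unchanged,
#    gaining graph capabilities without re-embedding the existing rows.
with graph_vector_store(
    embedding=ParserEmbeddings(dimension=2),
    table_name=TABLE_NAME_DEFAULT,
    setup_mode=SetupMode.OFF,
    drop=False,
) as gv_store:
    assert gv_store.get_by_document_id(document_id="AL") is not None
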
@ -0,0 +1,143 @@
from __future__ import annotations

import math

from langchain_community.graph_vectorstores.mmr_helper import MmrHelper

IDS = {
    "-1",
    "-2",
    "-3",
    "-4",
    "-5",
    "+1",
    "+2",
    "+3",
    "+4",
    "+5",
}


class TestMmrHelper:
    def test_mmr_helper_functional(self) -> None:
        helper = MmrHelper(k=3, query_embedding=[6, 5], lambda_mult=0.5)

        assert len(list(helper.candidate_ids())) == 0

        helper.add_candidates({"-1": [3, 5]})
        helper.add_candidates({"-2": [3, 5]})
        helper.add_candidates({"-3": [2, 6]})
        helper.add_candidates({"-4": [1, 6]})
        helper.add_candidates({"-5": [0, 6]})

        assert len(list(helper.candidate_ids())) == 5

        helper.add_candidates({"+1": [5, 3]})
        helper.add_candidates({"+2": [5, 3]})
        helper.add_candidates({"+3": [6, 2]})
        helper.add_candidates({"+4": [6, 1]})
        helper.add_candidates({"+5": [6, 0]})

        assert len(list(helper.candidate_ids())) == 10

        for idx in range(3):
            best_id = helper.pop_best()
            assert best_id in IDS
            assert len(list(helper.candidate_ids())) == 9 - idx
            assert best_id not in helper.candidate_ids()

    def test_mmr_helper_max_diversity(self) -> None:
        helper = MmrHelper(k=2, query_embedding=[6, 5], lambda_mult=0)
        helper.add_candidates({"-1": [3, 5]})
        helper.add_candidates({"-2": [3, 5]})
        helper.add_candidates({"-3": [2, 6]})
        helper.add_candidates({"-4": [1, 6]})
        helper.add_candidates({"-5": [0, 6]})

        best = {helper.pop_best(), helper.pop_best()}
        assert best == {"-1", "-5"}

    def test_mmr_helper_max_similarity(self) -> None:
        helper = MmrHelper(k=2, query_embedding=[6, 5], lambda_mult=1)
        helper.add_candidates({"-1": [3, 5]})
        helper.add_candidates({"-2": [3, 5]})
        helper.add_candidates({"-3": [2, 6]})
        helper.add_candidates({"-4": [1, 6]})
        helper.add_candidates({"-5": [0, 6]})

        best = {helper.pop_best(), helper.pop_best()}
        assert best == {"-1", "-2"}

    def test_mmr_helper_add_candidate(self) -> None:
        helper = MmrHelper(5, [0.0, 1.0])
        helper.add_candidates(
            {
                "a": [0.0, 1.0],
                "b": [1.0, 0.0],
            }
        )
        assert helper.best_id == "a"

    def test_mmr_helper_pop_best(self) -> None:
        helper = MmrHelper(5, [0.0, 1.0])
        helper.add_candidates(
            {
                "a": [0.0, 1.0],
                "b": [1.0, 0.0],
            }
        )
        assert helper.pop_best() == "a"
        assert helper.pop_best() == "b"
        assert helper.pop_best() is None

    def angular_embedding(self, angle: float) -> list[float]:
        return [math.cos(angle * math.pi), math.sin(angle * math.pi)]

    def test_mmr_helper_added_documents(self) -> None:
        """Test end to end construction and MMR search.

        The embedding function used here ensures `texts` become
        the following vectors on a circle (numbered v0 through v3):

               ______ v2
              /      \
             /        |  v1
        v3  |     .    | query
             |        /  v0
              |______/                 (N.B. very crude drawing)

        With fetch_k==2 and k==2, when the query is at angle 0.0
        (the point (1, 0)), one expects that v2 and v0 are returned
        (in some order) because v1 is "too close" to v0
        (and v0 is closer than v1).

        Both v2 and v3 are discovered after v0.
        """
        helper = MmrHelper(5, self.angular_embedding(0.0))

        # Fetching the 2 nearest neighbors to 0.0
        helper.add_candidates(
            {
                "v0": self.angular_embedding(-0.124),
                "v1": self.angular_embedding(+0.127),
            }
        )
        assert helper.pop_best() == "v0"

        # After v0 is selected, new nodes are discovered.
        # v2 is closer than v3. v1 is "too similar" to "v0" so it's not included.
        helper.add_candidates(
            {
                "v2": self.angular_embedding(+0.25),
                "v3": self.angular_embedding(+1.0),
            }
        )
        assert helper.pop_best() == "v2"

        assert math.isclose(
            helper.selected_similarity_scores[0], 0.9251, abs_tol=0.0001
        )
        assert math.isclose(
            helper.selected_similarity_scores[1], 0.7071, abs_tol=0.0001
        )
        assert math.isclose(helper.selected_mmr_scores[0], 0.4625, abs_tol=0.0001)
        assert math.isclose(helper.selected_mmr_scores[1], 0.1608, abs_tol=0.0001)
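
The numeric expectations follow from the standard MMR objective, which scores a candidate d against query q and the already-selected set S as lambda * sim(q, d) - (1 - lambda) * max over s in S of sim(d, s); with unit vectors, cosine similarity reduces to the cosine of the angle between them. A quick check of the two asserted scores, assuming the default lambda_mult of 0.5 (consistent with the asserted values):

import math


def cos_sim(a: float, b: float) -> float:
    # Cosine similarity of unit vectors at angles a*pi and b*pi.
    return math.cos((a - b) * math.pi)


lambda_mult = 0.5
# v0 is scored against the query alone (nothing selected yet):
score_v0 = lambda_mult * cos_sim(0.0, -0.124)
# v2 is scored against the query, minus its redundancy with the selected v0:
score_v2 = lambda_mult * cos_sim(0.0, 0.25) - (
    1 - lambda_mult
) * cos_sim(0.25, -0.124)
print(round(score_v0, 4), round(score_v2, 4))  # 0.4625 0.1608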