community[patch]: Use newer MetadataVectorCassandraTable in Cassandra vector store (#15987)

This change is needed because `cassio.vector.VectorTable` is deprecated.

Tested manually with `test_cassandra.py` vector store integration test.
This commit is contained in:
Christophe Bornet 2024-01-17 19:37:07 +01:00 committed by GitHub
parent 1fa056c324
commit fb940d11df
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -75,7 +75,7 @@ class Cassandra(VectorStore):
ttl_seconds: Optional[int] = None, ttl_seconds: Optional[int] = None,
) -> None: ) -> None:
try: try:
from cassio.vector import VectorTable from cassio.table import MetadataVectorCassandraTable
except (ImportError, ModuleNotFoundError): except (ImportError, ModuleNotFoundError):
raise ImportError( raise ImportError(
"Could not import cassio python package. " "Could not import cassio python package. "
@@ -90,11 +90,12 @@ class Cassandra(VectorStore):
# #
self._embedding_dimension = None self._embedding_dimension = None
# #
self.table = VectorTable( self.table = MetadataVectorCassandraTable(
session=session, session=session,
keyspace=keyspace, keyspace=keyspace,
table=table_name, table=table_name,
embedding_dimension=self._get_embedding_dimension(), vector_dimension=self._get_embedding_dimension(),
metadata_indexing="all",
primary_key_type="TEXT", primary_key_type="TEXT",
) )
@@ -127,7 +128,7 @@ class Cassandra(VectorStore):
self.table.clear() self.table.clear()
def delete_by_document_id(self, document_id: str) -> None: def delete_by_document_id(self, document_id: str) -> None:
return self.table.delete(document_id) return self.table.delete(row_id=document_id)
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]: def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
"""Delete by vector IDs. """Delete by vector IDs.
@@ -188,7 +189,11 @@ class Cassandra(VectorStore):
futures = [ futures = [
self.table.put_async( self.table.put_async(
text, embedding_vector, text_id, metadata, ttl_seconds row_id=text_id,
body_blob=text,
vector=embedding_vector,
metadata=metadata or {},
ttl_seconds=ttl_seconds,
) )
for text, embedding_vector, text_id, metadata in zip( for text, embedding_vector, text_id, metadata in zip(
batch_texts, batch_embedding_vectors, batch_ids, batch_metadatas batch_texts, batch_embedding_vectors, batch_ids, batch_metadatas
@@ -215,11 +220,10 @@ class Cassandra(VectorStore):
""" """
search_metadata = self._filter_to_metadata(filter) search_metadata = self._filter_to_metadata(filter)
# #
hits = self.table.search( hits = self.table.metric_ann_search(
embedding_vector=embedding, vector=embedding,
top_k=k, n=k,
metric="cos", metric="cos",
metric_threshold=None,
metadata=search_metadata, metadata=search_metadata,
) )
# We stick to 'cos' distance as it can be normalized on a 0-1 axis # We stick to 'cos' distance as it can be normalized on a 0-1 axis
@@ -227,11 +231,11 @@ class Cassandra(VectorStore):
return [ return [
( (
Document( Document(
page_content=hit["document"], page_content=hit["body_blob"],
metadata=hit["metadata"], metadata=hit["metadata"],
), ),
0.5 + 0.5 * hit["distance"], 0.5 + 0.5 * hit["distance"],
hit["document_id"], hit["row_id"],
) )
for hit in hits for hit in hits
] ]
@@ -340,31 +344,32 @@ class Cassandra(VectorStore):
""" """
search_metadata = self._filter_to_metadata(filter) search_metadata = self._filter_to_metadata(filter)
prefetchHits = self.table.search( prefetch_hits = list(
embedding_vector=embedding, self.table.metric_ann_search(
top_k=fetch_k, vector=embedding,
n=fetch_k,
metric="cos", metric="cos",
metric_threshold=None,
metadata=search_metadata, metadata=search_metadata,
) )
)
# let the mmr utility pick the *indices* in the above array # let the mmr utility pick the *indices* in the above array
mmrChosenIndices = maximal_marginal_relevance( mmr_chosen_indices = maximal_marginal_relevance(
np.array(embedding, dtype=np.float32), np.array(embedding, dtype=np.float32),
[pfHit["embedding_vector"] for pfHit in prefetchHits], [pf_hit["vector"] for pf_hit in prefetch_hits],
k=k, k=k,
lambda_mult=lambda_mult, lambda_mult=lambda_mult,
) )
mmrHits = [ mmr_hits = [
pfHit pf_hit
for pfIndex, pfHit in enumerate(prefetchHits) for pf_index, pf_hit in enumerate(prefetch_hits)
if pfIndex in mmrChosenIndices if pf_index in mmr_chosen_indices
] ]
return [ return [
Document( Document(
page_content=hit["document"], page_content=hit["body_blob"],
metadata=hit["metadata"], metadata=hit["metadata"],
) )
for hit in mmrHits for hit in mmr_hits
] ]
def max_marginal_relevance_search( def max_marginal_relevance_search(