mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-15 23:57:21 +00:00
community[patch]: Use newer MetadataVectorCassandraTable in Cassandra vector store (#15987)
as VectorTable is deprecated. Tested manually with the `test_cassandra.py` vector store integration test.
This commit is contained in:
parent
1fa056c324
commit
fb940d11df
@ -75,7 +75,7 @@ class Cassandra(VectorStore):
|
||||
ttl_seconds: Optional[int] = None,
|
||||
) -> None:
|
||||
try:
|
||||
from cassio.vector import VectorTable
|
||||
from cassio.table import MetadataVectorCassandraTable
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
raise ImportError(
|
||||
"Could not import cassio python package. "
|
||||
@ -90,11 +90,12 @@ class Cassandra(VectorStore):
|
||||
#
|
||||
self._embedding_dimension = None
|
||||
#
|
||||
self.table = VectorTable(
|
||||
self.table = MetadataVectorCassandraTable(
|
||||
session=session,
|
||||
keyspace=keyspace,
|
||||
table=table_name,
|
||||
embedding_dimension=self._get_embedding_dimension(),
|
||||
vector_dimension=self._get_embedding_dimension(),
|
||||
metadata_indexing="all",
|
||||
primary_key_type="TEXT",
|
||||
)
|
||||
|
||||
@ -127,7 +128,7 @@ class Cassandra(VectorStore):
|
||||
self.table.clear()
|
||||
|
||||
def delete_by_document_id(self, document_id: str) -> None:
|
||||
return self.table.delete(document_id)
|
||||
return self.table.delete(row_id=document_id)
|
||||
|
||||
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
|
||||
"""Delete by vector IDs.
|
||||
@ -188,7 +189,11 @@ class Cassandra(VectorStore):
|
||||
|
||||
futures = [
|
||||
self.table.put_async(
|
||||
text, embedding_vector, text_id, metadata, ttl_seconds
|
||||
row_id=text_id,
|
||||
body_blob=text,
|
||||
vector=embedding_vector,
|
||||
metadata=metadata or {},
|
||||
ttl_seconds=ttl_seconds,
|
||||
)
|
||||
for text, embedding_vector, text_id, metadata in zip(
|
||||
batch_texts, batch_embedding_vectors, batch_ids, batch_metadatas
|
||||
@ -215,11 +220,10 @@ class Cassandra(VectorStore):
|
||||
"""
|
||||
search_metadata = self._filter_to_metadata(filter)
|
||||
#
|
||||
hits = self.table.search(
|
||||
embedding_vector=embedding,
|
||||
top_k=k,
|
||||
hits = self.table.metric_ann_search(
|
||||
vector=embedding,
|
||||
n=k,
|
||||
metric="cos",
|
||||
metric_threshold=None,
|
||||
metadata=search_metadata,
|
||||
)
|
||||
# We stick to 'cos' distance as it can be normalized on a 0-1 axis
|
||||
@ -227,11 +231,11 @@ class Cassandra(VectorStore):
|
||||
return [
|
||||
(
|
||||
Document(
|
||||
page_content=hit["document"],
|
||||
page_content=hit["body_blob"],
|
||||
metadata=hit["metadata"],
|
||||
),
|
||||
0.5 + 0.5 * hit["distance"],
|
||||
hit["document_id"],
|
||||
hit["row_id"],
|
||||
)
|
||||
for hit in hits
|
||||
]
|
||||
@ -340,31 +344,32 @@ class Cassandra(VectorStore):
|
||||
"""
|
||||
search_metadata = self._filter_to_metadata(filter)
|
||||
|
||||
prefetchHits = self.table.search(
|
||||
embedding_vector=embedding,
|
||||
top_k=fetch_k,
|
||||
metric="cos",
|
||||
metric_threshold=None,
|
||||
metadata=search_metadata,
|
||||
prefetch_hits = list(
|
||||
self.table.metric_ann_search(
|
||||
vector=embedding,
|
||||
n=fetch_k,
|
||||
metric="cos",
|
||||
metadata=search_metadata,
|
||||
)
|
||||
)
|
||||
# let the mmr utility pick the *indices* in the above array
|
||||
mmrChosenIndices = maximal_marginal_relevance(
|
||||
mmr_chosen_indices = maximal_marginal_relevance(
|
||||
np.array(embedding, dtype=np.float32),
|
||||
[pfHit["embedding_vector"] for pfHit in prefetchHits],
|
||||
[pf_hit["vector"] for pf_hit in prefetch_hits],
|
||||
k=k,
|
||||
lambda_mult=lambda_mult,
|
||||
)
|
||||
mmrHits = [
|
||||
pfHit
|
||||
for pfIndex, pfHit in enumerate(prefetchHits)
|
||||
if pfIndex in mmrChosenIndices
|
||||
mmr_hits = [
|
||||
pf_hit
|
||||
for pf_index, pf_hit in enumerate(prefetch_hits)
|
||||
if pf_index in mmr_chosen_indices
|
||||
]
|
||||
return [
|
||||
Document(
|
||||
page_content=hit["document"],
|
||||
page_content=hit["body_blob"],
|
||||
metadata=hit["metadata"],
|
||||
)
|
||||
for hit in mmrHits
|
||||
for hit in mmr_hits
|
||||
]
|
||||
|
||||
def max_marginal_relevance_search(
|
||||
|
Loading…
Reference in New Issue
Block a user