community[patch]: Use newer MetadataVectorCassandraTable in Cassandra vector store (#15987)
The Cassandra vector store now uses the newer MetadataVectorCassandraTable, as VectorTable is deprecated. Tested manually with the `test_cassandra.py` vector store integration test.
This commit is contained in:
parent 1fa056c324
commit fb940d11df
@@ -75,7 +75,7 @@ class Cassandra(VectorStore):
         ttl_seconds: Optional[int] = None,
     ) -> None:
         try:
-            from cassio.vector import VectorTable
+            from cassio.table import MetadataVectorCassandraTable
         except (ImportError, ModuleNotFoundError):
             raise ImportError(
                 "Could not import cassio python package. "
@@ -90,11 +90,12 @@ class Cassandra(VectorStore):
         #
         self._embedding_dimension = None
         #
-        self.table = VectorTable(
+        self.table = MetadataVectorCassandraTable(
             session=session,
             keyspace=keyspace,
             table=table_name,
-            embedding_dimension=self._get_embedding_dimension(),
+            vector_dimension=self._get_embedding_dimension(),
+            metadata_indexing="all",
             primary_key_type="TEXT",
         )
 
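For context, a sketch of constructing the same `MetadataVectorCassandraTable` directly, with the keyword arguments the store now passes. The `cassandra-driver` connection, contact point, keyspace/table names, and vector dimension are placeholder assumptions:

```python
from cassandra.cluster import Cluster
from cassio.table import MetadataVectorCassandraTable

# Placeholder connection and names; adjust to your cluster and embedding model.
session = Cluster(["127.0.0.1"]).connect()
table = MetadataVectorCassandraTable(
    session=session,
    keyspace="demo_keyspace",   # keyspace must already exist
    table="demo_vectors",
    vector_dimension=1536,      # must match the embedding size
    metadata_indexing="all",    # index every metadata field, as the store does now
    primary_key_type="TEXT",
)
```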
@@ -127,7 +128,7 @@ class Cassandra(VectorStore):
         self.table.clear()
 
     def delete_by_document_id(self, document_id: str) -> None:
-        return self.table.delete(document_id)
+        return self.table.delete(row_id=document_id)
 
     def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
         """Delete by vector IDs.
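Deletion now goes through cassio's `delete(row_id=...)`. A short usage sketch at the vector-store level, assuming an already-constructed `Cassandra` instance named `vstore`:

```python
# Hypothetical usage; `vstore` is an existing Cassandra vector store instance.
ids = vstore.add_texts(["hello world"], metadatas=[{"source": "demo"}])
vstore.delete_by_document_id(ids[0])  # single row, forwarded as delete(row_id=...)
vstore.delete(ids=ids)                # batch variant declared just below
```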
@@ -188,7 +189,11 @@ class Cassandra(VectorStore):
 
             futures = [
                 self.table.put_async(
-                    text, embedding_vector, text_id, metadata, ttl_seconds
+                    row_id=text_id,
+                    body_blob=text,
+                    vector=embedding_vector,
+                    metadata=metadata or {},
+                    ttl_seconds=ttl_seconds,
                 )
                 for text, embedding_vector, text_id, metadata in zip(
                     batch_texts, batch_embedding_vectors, batch_ids, batch_metadatas
                )
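The write path now calls `put_async` with explicit keyword arguments (`row_id`, `body_blob`, `vector`, `metadata`, `ttl_seconds`) instead of positional ones. A condensed sketch of one batch, assuming the same batch variables as in the surrounding `add_texts`:

```python
# Sketch of one write batch; `table` is a MetadataVectorCassandraTable.
futures = [
    table.put_async(
        row_id=text_id,
        body_blob=text,             # raw document text
        vector=embedding_vector,    # embedding used for ANN search
        metadata=metadata or {},    # never None
        ttl_seconds=ttl_seconds,
    )
    for text, embedding_vector, text_id, metadata in zip(
        batch_texts, batch_embedding_vectors, batch_ids, batch_metadatas
    )
]
# One then waits on the returned futures before issuing the next batch.
for future in futures:
    future.result()
```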
@@ -215,11 +220,10 @@ class Cassandra(VectorStore):
         """
         search_metadata = self._filter_to_metadata(filter)
         #
-        hits = self.table.search(
-            embedding_vector=embedding,
-            top_k=k,
+        hits = self.table.metric_ann_search(
+            vector=embedding,
+            n=k,
             metric="cos",
-            metric_threshold=None,
             metadata=search_metadata,
         )
         # We stick to 'cos' distance as it can be normalized on a 0-1 axis
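The ANN lookup switches from `search(embedding_vector=..., top_k=..., metric_threshold=None)` to `metric_ann_search(vector=..., n=...)` with no threshold argument. A small sketch of querying the table directly and reading the hit fields the store consumes; the query vector, result count, and empty filter are placeholder assumptions:

```python
# Direct ANN query sketch; the query vector is a placeholder of the right dimension.
query_vector = [0.0] * 1536
hits = table.metric_ann_search(
    vector=query_vector,
    n=4,              # number of nearest rows to return
    metric="cos",     # cosine, so scores can be mapped onto a 0-1 axis
    metadata={},      # metadata filter; empty assumed to mean no filtering
)
for hit in hits:
    print(hit["row_id"], hit["distance"], hit["body_blob"], hit["metadata"])
```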
@@ -227,11 +231,11 @@ class Cassandra(VectorStore):
         return [
             (
                 Document(
-                    page_content=hit["document"],
+                    page_content=hit["body_blob"],
                     metadata=hit["metadata"],
                 ),
                 0.5 + 0.5 * hit["distance"],
-                hit["document_id"],
+                hit["row_id"],
             )
             for hit in hits
         ]
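Hit fields are renamed (`body_blob` for the text, `row_id` for the ID); the score mapping is unchanged. My reading, stated as an assumption: with `metric="cos"` the returned `distance` is a cosine similarity in [-1, 1], and `0.5 + 0.5 * distance` rescales it to a 0-1 relevance score:

```python
# Assumed interpretation of the "cos" metric: distance is a similarity in [-1, 1].
def cos_to_score(distance: float) -> float:
    return 0.5 + 0.5 * distance

assert cos_to_score(1.0) == 1.0    # identical direction
assert cos_to_score(-1.0) == 0.0   # opposite direction
assert cos_to_score(0.0) == 0.5    # orthogonal
```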
@@ -340,31 +344,32 @@ class Cassandra(VectorStore):
         """
         search_metadata = self._filter_to_metadata(filter)
 
-        prefetchHits = self.table.search(
-            embedding_vector=embedding,
-            top_k=fetch_k,
-            metric="cos",
-            metric_threshold=None,
-            metadata=search_metadata,
-        )
+        prefetch_hits = list(
+            self.table.metric_ann_search(
+                vector=embedding,
+                n=fetch_k,
+                metric="cos",
+                metadata=search_metadata,
+            )
+        )
         # let the mmr utility pick the *indices* in the above array
-        mmrChosenIndices = maximal_marginal_relevance(
+        mmr_chosen_indices = maximal_marginal_relevance(
             np.array(embedding, dtype=np.float32),
-            [pfHit["embedding_vector"] for pfHit in prefetchHits],
+            [pf_hit["vector"] for pf_hit in prefetch_hits],
             k=k,
             lambda_mult=lambda_mult,
         )
-        mmrHits = [
-            pfHit
-            for pfIndex, pfHit in enumerate(prefetchHits)
-            if pfIndex in mmrChosenIndices
+        mmr_hits = [
+            pf_hit
+            for pf_index, pf_hit in enumerate(prefetch_hits)
+            if pf_index in mmr_chosen_indices
         ]
         return [
             Document(
-                page_content=hit["document"],
+                page_content=hit["body_blob"],
                 metadata=hit["metadata"],
             )
-            for hit in mmrHits
+            for hit in mmr_hits
         ]
 
     def max_marginal_relevance_search(
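The MMR path keeps its shape with snake_case names and the new `vector` hit field: prefetch `fetch_k` candidates, let `maximal_marginal_relevance` pick indices, keep those hits. A standalone sketch of that selection step; the import path and toy vectors are assumptions:

```python
import numpy as np

# Assumed import location of the MMR utility used by the store.
from langchain_community.vectorstores.utils import maximal_marginal_relevance

# Toy prefetch result: each hit carries the stored vector under the new "vector" key.
query = np.array([1.0, 0.0], dtype=np.float32)
prefetch_hits = [
    {"row_id": "a", "vector": [0.9, 0.1]},
    {"row_id": "b", "vector": [0.89, 0.12]},
    {"row_id": "c", "vector": [0.1, 0.9]},
]

mmr_chosen_indices = maximal_marginal_relevance(
    query,
    [hit["vector"] for hit in prefetch_hits],
    k=2,
    lambda_mult=0.5,  # balance between query relevance and diversity
)
# Keep the chosen hits, preserving prefetch order, as the store does above.
mmr_hits = [hit for i, hit in enumerate(prefetch_hits) if i in mmr_chosen_indices]
```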