From 01ded5e2f974b1282619fe58aace5f445d7fa454 Mon Sep 17 00:00:00 2001
From: Eric Pinzur <2641606+epinzur@users.noreply.github.com>
Date: Thu, 22 Aug 2024 20:27:16 +0200
Subject: [PATCH] community: add metadata filter to CassandraGraphVectorStore (#25663)

- **Description:**
  - Added metadata filtering support to
    `langchain_community.graph_vectorstores.cassandra.CassandraGraphVectorStore`
  - Also fixed type conversion issues highlighted by mypy.
- **Dependencies:**
  - `ragstack-ai-knowledge-store 0.2.0` (released July 23, 2024)

---------

Co-authored-by: Chester Curme
---
 .../graph_vectorstores/cassandra.py           | 29 ++++++++++++++++---
 1 file changed, 25 insertions(+), 4 deletions(-)

diff --git a/libs/community/langchain_community/graph_vectorstores/cassandra.py b/libs/community/langchain_community/graph_vectorstores/cassandra.py
index 6fb04c60a13..33fc5ba8f91 100644
--- a/libs/community/langchain_community/graph_vectorstores/cassandra.py
+++ b/libs/community/langchain_community/graph_vectorstores/cassandra.py
@@ -120,18 +120,31 @@ class CassandraGraphVectorStore(GraphVectorStore):
         return store
 
     def similarity_search(
-        self, query: str, k: int = 4, **kwargs: Any
+        self,
+        query: str,
+        k: int = 4,
+        metadata_filter: dict[str, Any] = {},
+        **kwargs: Any,
     ) -> List[Document]:
         embedding_vector = self._embedding.embed_query(query)
         return self.similarity_search_by_vector(
             embedding_vector,
             k=k,
+            metadata_filter=metadata_filter,
         )
 
     def similarity_search_by_vector(
-        self, embedding: List[float], k: int = 4, **kwargs: Any
+        self,
+        embedding: List[float],
+        k: int = 4,
+        metadata_filter: dict[str, Any] = {},
+        **kwargs: Any,
     ) -> List[Document]:
-        nodes = self.store.similarity_search(embedding, k=k)
+        nodes = self.store.similarity_search(
+            embedding,
+            k=k,
+            metadata_filter=metadata_filter,
+        )
         return list(nodes_to_documents(nodes))
 
     def traversal_search(
@@ -140,9 +153,15 @@ class CassandraGraphVectorStore(GraphVectorStore):
         *,
         k: int = 4,
         depth: int = 1,
+        metadata_filter: dict[str, Any] = {},
         **kwargs: Any,
     ) -> Iterable[Document]:
-        nodes = self.store.traversal_search(query, k=k, depth=depth)
+        nodes = self.store.traversal_search(
+            query,
+            k=k,
+            depth=depth,
+            metadata_filter=metadata_filter,
+        )
         return nodes_to_documents(nodes)
 
     def mmr_traversal_search(
@@ -155,6 +174,7 @@ class CassandraGraphVectorStore(GraphVectorStore):
         adjacent_k: int = 10,
         lambda_mult: float = 0.5,
         score_threshold: float = float("-inf"),
+        metadata_filter: dict[str, Any] = {},
         **kwargs: Any,
     ) -> Iterable[Document]:
         nodes = self.store.mmr_traversal_search(
@@ -165,5 +185,6 @@ class CassandraGraphVectorStore(GraphVectorStore):
             adjacent_k=adjacent_k,
             lambda_mult=lambda_mult,
             score_threshold=score_threshold,
+            metadata_filter=metadata_filter,
         )
         return nodes_to_documents(nodes)