diff --git a/libs/community/langchain_community/graph_vectorstores/cassandra.py b/libs/community/langchain_community/graph_vectorstores/cassandra.py index 6fb04c60a13..33fc5ba8f91 100644 --- a/libs/community/langchain_community/graph_vectorstores/cassandra.py +++ b/libs/community/langchain_community/graph_vectorstores/cassandra.py @@ -120,18 +120,31 @@ class CassandraGraphVectorStore(GraphVectorStore): return store def similarity_search( - self, query: str, k: int = 4, **kwargs: Any + self, + query: str, + k: int = 4, + metadata_filter: dict[str, Any] = {}, + **kwargs: Any, ) -> List[Document]: embedding_vector = self._embedding.embed_query(query) return self.similarity_search_by_vector( embedding_vector, k=k, + metadata_filter=metadata_filter, ) def similarity_search_by_vector( - self, embedding: List[float], k: int = 4, **kwargs: Any + self, + embedding: List[float], + k: int = 4, + metadata_filter: dict[str, Any] = {}, + **kwargs: Any, ) -> List[Document]: - nodes = self.store.similarity_search(embedding, k=k) + nodes = self.store.similarity_search( + embedding, + k=k, + metadata_filter=metadata_filter, + ) return list(nodes_to_documents(nodes)) def traversal_search( @@ -140,9 +153,15 @@ class CassandraGraphVectorStore(GraphVectorStore): *, k: int = 4, depth: int = 1, + metadata_filter: dict[str, Any] = {}, **kwargs: Any, ) -> Iterable[Document]: - nodes = self.store.traversal_search(query, k=k, depth=depth) + nodes = self.store.traversal_search( + query, + k=k, + depth=depth, + metadata_filter=metadata_filter, + ) return nodes_to_documents(nodes) def mmr_traversal_search( @@ -155,6 +174,7 @@ class CassandraGraphVectorStore(GraphVectorStore): adjacent_k: int = 10, lambda_mult: float = 0.5, score_threshold: float = float("-inf"), + metadata_filter: dict[str, Any] = {}, **kwargs: Any, ) -> Iterable[Document]: nodes = self.store.mmr_traversal_search( @@ -165,5 +185,6 @@ class CassandraGraphVectorStore(GraphVectorStore): adjacent_k=adjacent_k, lambda_mult=lambda_mult, score_threshold=score_threshold, + metadata_filter=metadata_filter, ) return nodes_to_documents(nodes)