community: add metadata filter to CassandraGraphVectorStore (#25663)

- **Description:** 
  - Added metadata filtering support to
    `langchain_community.graph_vectorstores.cassandra.CassandraGraphVectorStore`.
  - Also fixed type conversion issues highlighted by mypy.
- **Dependencies:** 
  - `ragstack-ai-knowledge-store 0.2.0` (released July 23, 2024)

---------

Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
Eric Pinzur 2024-08-22 20:27:16 +02:00 committed by GitHub
parent 5b9290a449
commit 01ded5e2f9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -120,18 +120,31 @@ class CassandraGraphVectorStore(GraphVectorStore):
return store return store
def similarity_search( def similarity_search(
self, query: str, k: int = 4, **kwargs: Any self,
query: str,
k: int = 4,
metadata_filter: dict[str, Any] = {},
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
embedding_vector = self._embedding.embed_query(query) embedding_vector = self._embedding.embed_query(query)
return self.similarity_search_by_vector( return self.similarity_search_by_vector(
embedding_vector, embedding_vector,
k=k, k=k,
metadata_filter=metadata_filter,
) )
def similarity_search_by_vector( def similarity_search_by_vector(
self, embedding: List[float], k: int = 4, **kwargs: Any self,
embedding: List[float],
k: int = 4,
metadata_filter: dict[str, Any] = {},
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
nodes = self.store.similarity_search(embedding, k=k) nodes = self.store.similarity_search(
embedding,
k=k,
metadata_filter=metadata_filter,
)
return list(nodes_to_documents(nodes)) return list(nodes_to_documents(nodes))
def traversal_search( def traversal_search(
@ -140,9 +153,15 @@ class CassandraGraphVectorStore(GraphVectorStore):
*, *,
k: int = 4, k: int = 4,
depth: int = 1, depth: int = 1,
metadata_filter: dict[str, Any] = {},
**kwargs: Any, **kwargs: Any,
) -> Iterable[Document]: ) -> Iterable[Document]:
nodes = self.store.traversal_search(query, k=k, depth=depth) nodes = self.store.traversal_search(
query,
k=k,
depth=depth,
metadata_filter=metadata_filter,
)
return nodes_to_documents(nodes) return nodes_to_documents(nodes)
def mmr_traversal_search( def mmr_traversal_search(
@ -155,6 +174,7 @@ class CassandraGraphVectorStore(GraphVectorStore):
adjacent_k: int = 10, adjacent_k: int = 10,
lambda_mult: float = 0.5, lambda_mult: float = 0.5,
score_threshold: float = float("-inf"), score_threshold: float = float("-inf"),
metadata_filter: dict[str, Any] = {},
**kwargs: Any, **kwargs: Any,
) -> Iterable[Document]: ) -> Iterable[Document]:
nodes = self.store.mmr_traversal_search( nodes = self.store.mmr_traversal_search(
@ -165,5 +185,6 @@ class CassandraGraphVectorStore(GraphVectorStore):
adjacent_k=adjacent_k, adjacent_k=adjacent_k,
lambda_mult=lambda_mult, lambda_mult=lambda_mult,
score_threshold=score_threshold, score_threshold=score_threshold,
metadata_filter=metadata_filter,
) )
return nodes_to_documents(nodes) return nodes_to_documents(nodes)