mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-20 11:31:58 +00:00
community: add metadata filter to CassandraGraphVectorStore (#25663)
- **Description:** - Added metadata filtering support to `langchain_community.graph_vectorstores.cassandra.CassandraGraphVectorStore` - Also fixed type conversion issues highlighted by mypy. - **Dependencies:** - `ragstack-ai-knowledge-store 0.2.0` (released July 23, 2024) --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
parent
5b9290a449
commit
01ded5e2f9
@ -120,18 +120,31 @@ class CassandraGraphVectorStore(GraphVectorStore):
|
|||||||
return store
|
return store
|
||||||
|
|
||||||
def similarity_search(
|
def similarity_search(
|
||||||
self, query: str, k: int = 4, **kwargs: Any
|
self,
|
||||||
|
query: str,
|
||||||
|
k: int = 4,
|
||||||
|
metadata_filter: dict[str, Any] = {},
|
||||||
|
**kwargs: Any,
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
embedding_vector = self._embedding.embed_query(query)
|
embedding_vector = self._embedding.embed_query(query)
|
||||||
return self.similarity_search_by_vector(
|
return self.similarity_search_by_vector(
|
||||||
embedding_vector,
|
embedding_vector,
|
||||||
k=k,
|
k=k,
|
||||||
|
metadata_filter=metadata_filter,
|
||||||
)
|
)
|
||||||
|
|
||||||
def similarity_search_by_vector(
|
def similarity_search_by_vector(
|
||||||
self, embedding: List[float], k: int = 4, **kwargs: Any
|
self,
|
||||||
|
embedding: List[float],
|
||||||
|
k: int = 4,
|
||||||
|
metadata_filter: dict[str, Any] = {},
|
||||||
|
**kwargs: Any,
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
nodes = self.store.similarity_search(embedding, k=k)
|
nodes = self.store.similarity_search(
|
||||||
|
embedding,
|
||||||
|
k=k,
|
||||||
|
metadata_filter=metadata_filter,
|
||||||
|
)
|
||||||
return list(nodes_to_documents(nodes))
|
return list(nodes_to_documents(nodes))
|
||||||
|
|
||||||
def traversal_search(
|
def traversal_search(
|
||||||
@ -140,9 +153,15 @@ class CassandraGraphVectorStore(GraphVectorStore):
|
|||||||
*,
|
*,
|
||||||
k: int = 4,
|
k: int = 4,
|
||||||
depth: int = 1,
|
depth: int = 1,
|
||||||
|
metadata_filter: dict[str, Any] = {},
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Iterable[Document]:
|
) -> Iterable[Document]:
|
||||||
nodes = self.store.traversal_search(query, k=k, depth=depth)
|
nodes = self.store.traversal_search(
|
||||||
|
query,
|
||||||
|
k=k,
|
||||||
|
depth=depth,
|
||||||
|
metadata_filter=metadata_filter,
|
||||||
|
)
|
||||||
return nodes_to_documents(nodes)
|
return nodes_to_documents(nodes)
|
||||||
|
|
||||||
def mmr_traversal_search(
|
def mmr_traversal_search(
|
||||||
@ -155,6 +174,7 @@ class CassandraGraphVectorStore(GraphVectorStore):
|
|||||||
adjacent_k: int = 10,
|
adjacent_k: int = 10,
|
||||||
lambda_mult: float = 0.5,
|
lambda_mult: float = 0.5,
|
||||||
score_threshold: float = float("-inf"),
|
score_threshold: float = float("-inf"),
|
||||||
|
metadata_filter: dict[str, Any] = {},
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Iterable[Document]:
|
) -> Iterable[Document]:
|
||||||
nodes = self.store.mmr_traversal_search(
|
nodes = self.store.mmr_traversal_search(
|
||||||
@ -165,5 +185,6 @@ class CassandraGraphVectorStore(GraphVectorStore):
|
|||||||
adjacent_k=adjacent_k,
|
adjacent_k=adjacent_k,
|
||||||
lambda_mult=lambda_mult,
|
lambda_mult=lambda_mult,
|
||||||
score_threshold=score_threshold,
|
score_threshold=score_threshold,
|
||||||
|
metadata_filter=metadata_filter,
|
||||||
)
|
)
|
||||||
return nodes_to_documents(nodes)
|
return nodes_to_documents(nodes)
|
||||||
|
Loading…
Reference in New Issue
Block a user