community: fixed bug in GraphVectorStoreRetriever (#27846)

Description:

This fixes an issue that mistakenly created in
https://github.com/langchain-ai/langchain/pull/27253. The issue
currently exists only in `langchain-community==0.3.4`.

Test cases were added to prevent this issue in the future.

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Eric Pinzur
2024-11-04 21:27:17 +01:00
committed by GitHub
parent eecf95df9b
commit 8eb38622a6
2 changed files with 36 additions and 7 deletions

View File

@@ -440,6 +440,17 @@ class TestCassandraGraphVectorStore:
ts_labels = {doc.metadata["label"] for doc in ts_response}
assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"}
# verify the same works as a retriever
retriever = g_store.as_retriever(
search_type="traversal", search_kwargs={"k": 2, "depth": 2}
)
ts_labels = {
doc.metadata["label"]
for doc in retriever.get_relevant_documents(query="[2, 10]")
}
assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"}
async def test_gvs_traversal_search_async(
self,
populated_graph_vector_store_d2: CassandraGraphVectorStore,
@@ -453,6 +464,17 @@ class TestCassandraGraphVectorStore:
# so ordering is not deterministic:
assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"}
# verify the same works as a retriever
retriever = g_store.as_retriever(
search_type="traversal", search_kwargs={"k": 2, "depth": 2}
)
ts_labels = {
doc.metadata["label"]
for doc in await retriever.aget_relevant_documents(query="[2, 10]")
}
assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"}
def test_gvs_mmr_traversal_search_sync(
self,
populated_graph_vector_store_d2: CassandraGraphVectorStore,