mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-20 13:54:48 +00:00
community: fixed bug in GraphVectorStoreRetriever (#27846)
Description: This fixes an issue that mistakenly created in https://github.com/langchain-ai/langchain/pull/27253. The issue currently exists only in `langchain-community==0.3.4`. Test cases were added to prevent this issue in the future. Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
eecf95df9b
commit
8eb38622a6
@ -8,6 +8,7 @@ from typing import (
|
||||
ClassVar,
|
||||
Optional,
|
||||
Sequence,
|
||||
cast,
|
||||
)
|
||||
|
||||
from langchain_core._api import beta
|
||||
@ -701,7 +702,7 @@ class GraphVectorStore(VectorStore):
|
||||
docsearch.as_retriever(search_kwargs={'k': 1})
|
||||
|
||||
"""
|
||||
return GraphVectorStoreRetriever(vector_store=self, **kwargs)
|
||||
return GraphVectorStoreRetriever(vectorstore=self, **kwargs)
|
||||
|
||||
|
||||
@beta(message="Added in version 0.3.1 of langchain_community. API subject to change.")
|
||||
@ -837,8 +838,8 @@ class GraphVectorStoreRetriever(VectorStoreRetriever):
|
||||
retriever = graph_vectorstore.as_retriever(search_kwargs={"score_threshold": 0.5})
|
||||
""" # noqa: E501
|
||||
|
||||
vector_store: GraphVectorStore
|
||||
"""GraphVectorStore to use for retrieval."""
|
||||
vectorstore: VectorStore
|
||||
"""VectorStore to use for retrieval."""
|
||||
search_type: str = "traversal"
|
||||
"""Type of search to perform. Defaults to "traversal"."""
|
||||
allowed_search_types: ClassVar[Collection[str]] = (
|
||||
@ -849,14 +850,20 @@ class GraphVectorStoreRetriever(VectorStoreRetriever):
|
||||
"mmr_traversal",
|
||||
)
|
||||
|
||||
@property
|
||||
def graph_vectorstore(self) -> GraphVectorStore:
|
||||
return cast(GraphVectorStore, self.vectorstore)
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> list[Document]:
|
||||
if self.search_type == "traversal":
|
||||
return list(self.vector_store.traversal_search(query, **self.search_kwargs))
|
||||
return list(
|
||||
self.graph_vectorstore.traversal_search(query, **self.search_kwargs)
|
||||
)
|
||||
elif self.search_type == "mmr_traversal":
|
||||
return list(
|
||||
self.vector_store.mmr_traversal_search(query, **self.search_kwargs)
|
||||
self.graph_vectorstore.mmr_traversal_search(query, **self.search_kwargs)
|
||||
)
|
||||
else:
|
||||
return super()._get_relevant_documents(query, run_manager=run_manager)
|
||||
@ -867,14 +874,14 @@ class GraphVectorStoreRetriever(VectorStoreRetriever):
|
||||
if self.search_type == "traversal":
|
||||
return [
|
||||
doc
|
||||
async for doc in self.vector_store.atraversal_search(
|
||||
async for doc in self.graph_vectorstore.atraversal_search(
|
||||
query, **self.search_kwargs
|
||||
)
|
||||
]
|
||||
elif self.search_type == "mmr_traversal":
|
||||
return [
|
||||
doc
|
||||
async for doc in self.vector_store.ammr_traversal_search(
|
||||
async for doc in self.graph_vectorstore.ammr_traversal_search(
|
||||
query, **self.search_kwargs
|
||||
)
|
||||
]
|
||||
|
@ -440,6 +440,17 @@ class TestCassandraGraphVectorStore:
|
||||
ts_labels = {doc.metadata["label"] for doc in ts_response}
|
||||
assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"}
|
||||
|
||||
# verify the same works as a retriever
|
||||
retriever = g_store.as_retriever(
|
||||
search_type="traversal", search_kwargs={"k": 2, "depth": 2}
|
||||
)
|
||||
|
||||
ts_labels = {
|
||||
doc.metadata["label"]
|
||||
for doc in retriever.get_relevant_documents(query="[2, 10]")
|
||||
}
|
||||
assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"}
|
||||
|
||||
async def test_gvs_traversal_search_async(
|
||||
self,
|
||||
populated_graph_vector_store_d2: CassandraGraphVectorStore,
|
||||
@ -453,6 +464,17 @@ class TestCassandraGraphVectorStore:
|
||||
# so ordering is not deterministic:
|
||||
assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"}
|
||||
|
||||
# verify the same works as a retriever
|
||||
retriever = g_store.as_retriever(
|
||||
search_type="traversal", search_kwargs={"k": 2, "depth": 2}
|
||||
)
|
||||
|
||||
ts_labels = {
|
||||
doc.metadata["label"]
|
||||
for doc in await retriever.aget_relevant_documents(query="[2, 10]")
|
||||
}
|
||||
assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"}
|
||||
|
||||
def test_gvs_mmr_traversal_search_sync(
|
||||
self,
|
||||
populated_graph_vector_store_d2: CassandraGraphVectorStore,
|
||||
|
Loading…
Reference in New Issue
Block a user