mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 06:14:37 +00:00
community: fixed bug in GraphVectorStoreRetriever (#27846)
Description: This fixes an issue that mistakenly created in https://github.com/langchain-ai/langchain/pull/27253. The issue currently exists only in `langchain-community==0.3.4`. Test cases were added to prevent this issue in the future. Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
eecf95df9b
commit
8eb38622a6
@ -8,6 +8,7 @@ from typing import (
|
|||||||
ClassVar,
|
ClassVar,
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
from langchain_core._api import beta
|
from langchain_core._api import beta
|
||||||
@ -701,7 +702,7 @@ class GraphVectorStore(VectorStore):
|
|||||||
docsearch.as_retriever(search_kwargs={'k': 1})
|
docsearch.as_retriever(search_kwargs={'k': 1})
|
||||||
|
|
||||||
"""
|
"""
|
||||||
return GraphVectorStoreRetriever(vector_store=self, **kwargs)
|
return GraphVectorStoreRetriever(vectorstore=self, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@beta(message="Added in version 0.3.1 of langchain_community. API subject to change.")
|
@beta(message="Added in version 0.3.1 of langchain_community. API subject to change.")
|
||||||
@ -837,8 +838,8 @@ class GraphVectorStoreRetriever(VectorStoreRetriever):
|
|||||||
retriever = graph_vectorstore.as_retriever(search_kwargs={"score_threshold": 0.5})
|
retriever = graph_vectorstore.as_retriever(search_kwargs={"score_threshold": 0.5})
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
|
|
||||||
vector_store: GraphVectorStore
|
vectorstore: VectorStore
|
||||||
"""GraphVectorStore to use for retrieval."""
|
"""VectorStore to use for retrieval."""
|
||||||
search_type: str = "traversal"
|
search_type: str = "traversal"
|
||||||
"""Type of search to perform. Defaults to "traversal"."""
|
"""Type of search to perform. Defaults to "traversal"."""
|
||||||
allowed_search_types: ClassVar[Collection[str]] = (
|
allowed_search_types: ClassVar[Collection[str]] = (
|
||||||
@ -849,14 +850,20 @@ class GraphVectorStoreRetriever(VectorStoreRetriever):
|
|||||||
"mmr_traversal",
|
"mmr_traversal",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def graph_vectorstore(self) -> GraphVectorStore:
|
||||||
|
return cast(GraphVectorStore, self.vectorstore)
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||||
) -> list[Document]:
|
) -> list[Document]:
|
||||||
if self.search_type == "traversal":
|
if self.search_type == "traversal":
|
||||||
return list(self.vector_store.traversal_search(query, **self.search_kwargs))
|
return list(
|
||||||
|
self.graph_vectorstore.traversal_search(query, **self.search_kwargs)
|
||||||
|
)
|
||||||
elif self.search_type == "mmr_traversal":
|
elif self.search_type == "mmr_traversal":
|
||||||
return list(
|
return list(
|
||||||
self.vector_store.mmr_traversal_search(query, **self.search_kwargs)
|
self.graph_vectorstore.mmr_traversal_search(query, **self.search_kwargs)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return super()._get_relevant_documents(query, run_manager=run_manager)
|
return super()._get_relevant_documents(query, run_manager=run_manager)
|
||||||
@ -867,14 +874,14 @@ class GraphVectorStoreRetriever(VectorStoreRetriever):
|
|||||||
if self.search_type == "traversal":
|
if self.search_type == "traversal":
|
||||||
return [
|
return [
|
||||||
doc
|
doc
|
||||||
async for doc in self.vector_store.atraversal_search(
|
async for doc in self.graph_vectorstore.atraversal_search(
|
||||||
query, **self.search_kwargs
|
query, **self.search_kwargs
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
elif self.search_type == "mmr_traversal":
|
elif self.search_type == "mmr_traversal":
|
||||||
return [
|
return [
|
||||||
doc
|
doc
|
||||||
async for doc in self.vector_store.ammr_traversal_search(
|
async for doc in self.graph_vectorstore.ammr_traversal_search(
|
||||||
query, **self.search_kwargs
|
query, **self.search_kwargs
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
@ -440,6 +440,17 @@ class TestCassandraGraphVectorStore:
|
|||||||
ts_labels = {doc.metadata["label"] for doc in ts_response}
|
ts_labels = {doc.metadata["label"] for doc in ts_response}
|
||||||
assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"}
|
assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"}
|
||||||
|
|
||||||
|
# verify the same works as a retriever
|
||||||
|
retriever = g_store.as_retriever(
|
||||||
|
search_type="traversal", search_kwargs={"k": 2, "depth": 2}
|
||||||
|
)
|
||||||
|
|
||||||
|
ts_labels = {
|
||||||
|
doc.metadata["label"]
|
||||||
|
for doc in retriever.get_relevant_documents(query="[2, 10]")
|
||||||
|
}
|
||||||
|
assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"}
|
||||||
|
|
||||||
async def test_gvs_traversal_search_async(
|
async def test_gvs_traversal_search_async(
|
||||||
self,
|
self,
|
||||||
populated_graph_vector_store_d2: CassandraGraphVectorStore,
|
populated_graph_vector_store_d2: CassandraGraphVectorStore,
|
||||||
@ -453,6 +464,17 @@ class TestCassandraGraphVectorStore:
|
|||||||
# so ordering is not deterministic:
|
# so ordering is not deterministic:
|
||||||
assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"}
|
assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"}
|
||||||
|
|
||||||
|
# verify the same works as a retriever
|
||||||
|
retriever = g_store.as_retriever(
|
||||||
|
search_type="traversal", search_kwargs={"k": 2, "depth": 2}
|
||||||
|
)
|
||||||
|
|
||||||
|
ts_labels = {
|
||||||
|
doc.metadata["label"]
|
||||||
|
for doc in await retriever.aget_relevant_documents(query="[2, 10]")
|
||||||
|
}
|
||||||
|
assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"}
|
||||||
|
|
||||||
def test_gvs_mmr_traversal_search_sync(
|
def test_gvs_mmr_traversal_search_sync(
|
||||||
self,
|
self,
|
||||||
populated_graph_vector_store_d2: CassandraGraphVectorStore,
|
populated_graph_vector_store_d2: CassandraGraphVectorStore,
|
||||||
|
Loading…
Reference in New Issue
Block a user