mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-15 15:46:47 +00:00
retrievers: MultiVectorRetriever similarity_score_threshold search type (#23539)
Description: support MultiVectorRetriever similarity_score_threshold search type. Issue: #23387 #19404 --------- Co-authored-by: gongwn1 <gongwn1@lenovo.com>
This commit is contained in:
parent
20151384d7
commit
a001037319
@ -19,6 +19,8 @@ class SearchType(str, Enum):
|
|||||||
|
|
||||||
similarity = "similarity"
|
similarity = "similarity"
|
||||||
"""Similarity search."""
|
"""Similarity search."""
|
||||||
|
similarity_score_threshold = "similarity_score_threshold"
|
||||||
|
"""Similarity search with a score threshold."""
|
||||||
mmr = "mmr"
|
mmr = "mmr"
|
||||||
"""Maximal Marginal Relevance reranking of similarity search."""
|
"""Maximal Marginal Relevance reranking of similarity search."""
|
||||||
|
|
||||||
@ -64,6 +66,13 @@ class MultiVectorRetriever(BaseRetriever):
|
|||||||
sub_docs = self.vectorstore.max_marginal_relevance_search(
|
sub_docs = self.vectorstore.max_marginal_relevance_search(
|
||||||
query, **self.search_kwargs
|
query, **self.search_kwargs
|
||||||
)
|
)
|
||||||
|
elif self.search_type == SearchType.similarity_score_threshold:
|
||||||
|
sub_docs_and_similarities = (
|
||||||
|
self.vectorstore.similarity_search_with_relevance_scores(
|
||||||
|
query, **self.search_kwargs
|
||||||
|
)
|
||||||
|
)
|
||||||
|
sub_docs = [sub_doc for sub_doc, _ in sub_docs_and_similarities]
|
||||||
else:
|
else:
|
||||||
sub_docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
|
sub_docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
|
||||||
|
|
||||||
@ -89,6 +98,13 @@ class MultiVectorRetriever(BaseRetriever):
|
|||||||
sub_docs = await self.vectorstore.amax_marginal_relevance_search(
|
sub_docs = await self.vectorstore.amax_marginal_relevance_search(
|
||||||
query, **self.search_kwargs
|
query, **self.search_kwargs
|
||||||
)
|
)
|
||||||
|
elif self.search_type == SearchType.similarity_score_threshold:
|
||||||
|
sub_docs_and_similarities = (
|
||||||
|
await self.vectorstore.asimilarity_search_with_relevance_scores(
|
||||||
|
query, **self.search_kwargs
|
||||||
|
)
|
||||||
|
)
|
||||||
|
sub_docs = [sub_doc for sub_doc, _ in sub_docs_and_similarities]
|
||||||
else:
|
else:
|
||||||
sub_docs = await self.vectorstore.asimilarity_search(
|
sub_docs = await self.vectorstore.asimilarity_search(
|
||||||
query, **self.search_kwargs
|
query, **self.search_kwargs
|
||||||
|
@ -1,13 +1,20 @@
|
|||||||
from typing import Any, List
|
from typing import Any, Callable, List, Tuple
|
||||||
|
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
|
|
||||||
from langchain.retrievers.multi_vector import MultiVectorRetriever
|
from langchain.retrievers.multi_vector import MultiVectorRetriever, SearchType
|
||||||
from langchain.storage import InMemoryStore
|
from langchain.storage import InMemoryStore
|
||||||
from tests.unit_tests.indexes.test_indexing import InMemoryVectorStore
|
from tests.unit_tests.indexes.test_indexing import InMemoryVectorStore
|
||||||
|
|
||||||
|
|
||||||
class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
|
class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
|
||||||
|
@staticmethod
|
||||||
|
def _identity_fn(score: float) -> float:
|
||||||
|
return score
|
||||||
|
|
||||||
|
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
||||||
|
return self._identity_fn
|
||||||
|
|
||||||
def similarity_search(
|
def similarity_search(
|
||||||
self, query: str, k: int = 4, **kwargs: Any
|
self, query: str, k: int = 4, **kwargs: Any
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
@ -16,6 +23,14 @@ class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
|
|||||||
return []
|
return []
|
||||||
return [res]
|
return [res]
|
||||||
|
|
||||||
|
def similarity_search_with_score(
|
||||||
|
self, query: str, k: int = 4, **kwargs: Any
|
||||||
|
) -> List[Tuple[Document, float]]:
|
||||||
|
res = self.store.get(query)
|
||||||
|
if res is None:
|
||||||
|
return []
|
||||||
|
return [(res, 0.8)]
|
||||||
|
|
||||||
|
|
||||||
def test_multi_vector_retriever_initialization() -> None:
|
def test_multi_vector_retriever_initialization() -> None:
|
||||||
vectorstore = InMemoryVectorstoreWithSearch()
|
vectorstore = InMemoryVectorstoreWithSearch()
|
||||||
@ -41,3 +56,65 @@ async def test_multi_vector_retriever_initialization_async() -> None:
|
|||||||
results = await retriever.ainvoke("1")
|
results = await retriever.ainvoke("1")
|
||||||
assert len(results) > 0
|
assert len(results) > 0
|
||||||
assert results[0].page_content == "test document"
|
assert results[0].page_content == "test document"
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_vector_retriever_similarity_search_with_score() -> None:
|
||||||
|
documents = [Document(page_content="test document", metadata={"doc_id": "1"})]
|
||||||
|
vectorstore = InMemoryVectorstoreWithSearch()
|
||||||
|
vectorstore.add_documents(documents, ids=["1"])
|
||||||
|
|
||||||
|
# score_threshold = 0.5
|
||||||
|
retriever = MultiVectorRetriever( # type: ignore[call-arg]
|
||||||
|
vectorstore=vectorstore,
|
||||||
|
docstore=InMemoryStore(),
|
||||||
|
doc_id="doc_id",
|
||||||
|
search_kwargs={"score_threshold": 0.5},
|
||||||
|
search_type=SearchType.similarity_score_threshold,
|
||||||
|
)
|
||||||
|
retriever.docstore.mset(list(zip(["1"], documents)))
|
||||||
|
results = retriever.invoke("1")
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0].page_content == "test document"
|
||||||
|
|
||||||
|
# score_threshold = 0.9
|
||||||
|
retriever = MultiVectorRetriever( # type: ignore[call-arg]
|
||||||
|
vectorstore=vectorstore,
|
||||||
|
docstore=InMemoryStore(),
|
||||||
|
doc_id="doc_id",
|
||||||
|
search_kwargs={"score_threshold": 0.9},
|
||||||
|
search_type=SearchType.similarity_score_threshold,
|
||||||
|
)
|
||||||
|
retriever.docstore.mset(list(zip(["1"], documents)))
|
||||||
|
results = retriever.invoke("1")
|
||||||
|
assert len(results) == 0
|
||||||
|
|
||||||
|
|
||||||
|
async def test_multi_vector_retriever_similarity_search_with_score_async() -> None:
|
||||||
|
documents = [Document(page_content="test document", metadata={"doc_id": "1"})]
|
||||||
|
vectorstore = InMemoryVectorstoreWithSearch()
|
||||||
|
await vectorstore.aadd_documents(documents, ids=["1"])
|
||||||
|
|
||||||
|
# score_threshold = 0.5
|
||||||
|
retriever = MultiVectorRetriever( # type: ignore[call-arg]
|
||||||
|
vectorstore=vectorstore,
|
||||||
|
docstore=InMemoryStore(),
|
||||||
|
doc_id="doc_id",
|
||||||
|
search_kwargs={"score_threshold": 0.5},
|
||||||
|
search_type=SearchType.similarity_score_threshold,
|
||||||
|
)
|
||||||
|
await retriever.docstore.amset(list(zip(["1"], documents)))
|
||||||
|
results = retriever.invoke("1")
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0].page_content == "test document"
|
||||||
|
|
||||||
|
# score_threshold = 0.9
|
||||||
|
retriever = MultiVectorRetriever( # type: ignore[call-arg]
|
||||||
|
vectorstore=vectorstore,
|
||||||
|
docstore=InMemoryStore(),
|
||||||
|
doc_id="doc_id",
|
||||||
|
search_kwargs={"score_threshold": 0.9},
|
||||||
|
search_type=SearchType.similarity_score_threshold,
|
||||||
|
)
|
||||||
|
await retriever.docstore.amset(list(zip(["1"], documents)))
|
||||||
|
results = retriever.invoke("1")
|
||||||
|
assert len(results) == 0
|
||||||
|
Loading…
Reference in New Issue
Block a user