From a001037319537e8196b85e8704bb5c241055fe90 Mon Sep 17 00:00:00 2001 From: wenngong <76683249+wenngong@users.noreply.github.com> Date: Tue, 16 Jul 2024 01:31:34 +0800 Subject: [PATCH] retrievers: MultiVectorRetriever similarity_score_threshold search type (#23539) Description: support MultiVectorRetriever similarity_score_threshold search type. Issue: #23387 #19404 --------- Co-authored-by: gongwn1 --- .../langchain/retrievers/multi_vector.py | 16 ++++ .../retrievers/test_multi_vector.py | 81 ++++++++++++++++++- 2 files changed, 95 insertions(+), 2 deletions(-) diff --git a/libs/langchain/langchain/retrievers/multi_vector.py b/libs/langchain/langchain/retrievers/multi_vector.py index 2ac989914ae..54a4d935dcd 100644 --- a/libs/langchain/langchain/retrievers/multi_vector.py +++ b/libs/langchain/langchain/retrievers/multi_vector.py @@ -19,6 +19,8 @@ class SearchType(str, Enum): similarity = "similarity" """Similarity search.""" + similarity_score_threshold = "similarity_score_threshold" + """Similarity search with a score threshold.""" mmr = "mmr" """Maximal Marginal Relevance reranking of similarity search.""" @@ -64,6 +66,13 @@ class MultiVectorRetriever(BaseRetriever): sub_docs = self.vectorstore.max_marginal_relevance_search( query, **self.search_kwargs ) + elif self.search_type == SearchType.similarity_score_threshold: + sub_docs_and_similarities = ( + self.vectorstore.similarity_search_with_relevance_scores( + query, **self.search_kwargs + ) + ) + sub_docs = [sub_doc for sub_doc, _ in sub_docs_and_similarities] else: sub_docs = self.vectorstore.similarity_search(query, **self.search_kwargs) @@ -89,6 +98,13 @@ class MultiVectorRetriever(BaseRetriever): sub_docs = await self.vectorstore.amax_marginal_relevance_search( query, **self.search_kwargs ) + elif self.search_type == SearchType.similarity_score_threshold: + sub_docs_and_similarities = ( + await self.vectorstore.asimilarity_search_with_relevance_scores( + query, **self.search_kwargs + ) + ) + sub_docs = [sub_doc for sub_doc, _ in sub_docs_and_similarities] else: sub_docs = await self.vectorstore.asimilarity_search( query, **self.search_kwargs diff --git a/libs/langchain/tests/unit_tests/retrievers/test_multi_vector.py b/libs/langchain/tests/unit_tests/retrievers/test_multi_vector.py index e35244c77d4..2fdc8009fb7 100644 --- a/libs/langchain/tests/unit_tests/retrievers/test_multi_vector.py +++ b/libs/langchain/tests/unit_tests/retrievers/test_multi_vector.py @@ -1,13 +1,20 @@ -from typing import Any, List +from typing import Any, Callable, List, Tuple from langchain_core.documents import Document -from langchain.retrievers.multi_vector import MultiVectorRetriever +from langchain.retrievers.multi_vector import MultiVectorRetriever, SearchType from langchain.storage import InMemoryStore from tests.unit_tests.indexes.test_indexing import InMemoryVectorStore class InMemoryVectorstoreWithSearch(InMemoryVectorStore): + @staticmethod + def _identity_fn(score: float) -> float: + return score + + def _select_relevance_score_fn(self) -> Callable[[float], float]: + return self._identity_fn + def similarity_search( self, query: str, k: int = 4, **kwargs: Any ) -> List[Document]: @@ -16,6 +23,14 @@ class InMemoryVectorstoreWithSearch(InMemoryVectorStore): return [] return [res] + def similarity_search_with_score( + self, query: str, k: int = 4, **kwargs: Any + ) -> List[Tuple[Document, float]]: + res = self.store.get(query) + if res is None: + return [] + return [(res, 0.8)] + def test_multi_vector_retriever_initialization() -> None: vectorstore = InMemoryVectorstoreWithSearch() @@ -41,3 +56,65 @@ async def test_multi_vector_retriever_initialization_async() -> None: results = await retriever.ainvoke("1") assert len(results) > 0 assert results[0].page_content == "test document" + + +def test_multi_vector_retriever_similarity_search_with_score() -> None: + documents = [Document(page_content="test document", metadata={"doc_id": "1"})] + vectorstore = InMemoryVectorstoreWithSearch() + vectorstore.add_documents(documents, ids=["1"]) + + # score_threshold = 0.5 + retriever = MultiVectorRetriever( # type: ignore[call-arg] + vectorstore=vectorstore, + docstore=InMemoryStore(), + doc_id="doc_id", + search_kwargs={"score_threshold": 0.5}, + search_type=SearchType.similarity_score_threshold, + ) + retriever.docstore.mset(list(zip(["1"], documents))) + results = retriever.invoke("1") + assert len(results) == 1 + assert results[0].page_content == "test document" + + # score_threshold = 0.9 + retriever = MultiVectorRetriever( # type: ignore[call-arg] + vectorstore=vectorstore, + docstore=InMemoryStore(), + doc_id="doc_id", + search_kwargs={"score_threshold": 0.9}, + search_type=SearchType.similarity_score_threshold, + ) + retriever.docstore.mset(list(zip(["1"], documents))) + results = retriever.invoke("1") + assert len(results) == 0 + + +async def test_multi_vector_retriever_similarity_search_with_score_async() -> None: + documents = [Document(page_content="test document", metadata={"doc_id": "1"})] + vectorstore = InMemoryVectorstoreWithSearch() + await vectorstore.aadd_documents(documents, ids=["1"]) + + # score_threshold = 0.5 + retriever = MultiVectorRetriever( # type: ignore[call-arg] + vectorstore=vectorstore, + docstore=InMemoryStore(), + doc_id="doc_id", + search_kwargs={"score_threshold": 0.5}, + search_type=SearchType.similarity_score_threshold, + ) + await retriever.docstore.amset(list(zip(["1"], documents))) + results = retriever.invoke("1") + assert len(results) == 1 + assert results[0].page_content == "test document" + + # score_threshold = 0.9 + retriever = MultiVectorRetriever( # type: ignore[call-arg] + vectorstore=vectorstore, + docstore=InMemoryStore(), + doc_id="doc_id", + search_kwargs={"score_threshold": 0.9}, + search_type=SearchType.similarity_score_threshold, + ) + await retriever.docstore.amset(list(zip(["1"], documents))) + results = retriever.invoke("1") + assert len(results) == 0