retrievers: MultiVectorRetriever similarity_score_threshold search type (#23539)

Description: support MultiVectorRetriever similarity_score_threshold
search type.

Issue: #23387 #19404

---------

Co-authored-by: gongwn1 <gongwn1@lenovo.com>
This commit is contained in:
wenngong 2024-07-16 01:31:34 +08:00 committed by GitHub
parent 20151384d7
commit a001037319
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 95 additions and 2 deletions

View File

@ -19,6 +19,8 @@ class SearchType(str, Enum):
similarity = "similarity" similarity = "similarity"
"""Similarity search.""" """Similarity search."""
similarity_score_threshold = "similarity_score_threshold"
"""Similarity search with a score threshold."""
mmr = "mmr" mmr = "mmr"
"""Maximal Marginal Relevance reranking of similarity search.""" """Maximal Marginal Relevance reranking of similarity search."""
@ -64,6 +66,13 @@ class MultiVectorRetriever(BaseRetriever):
sub_docs = self.vectorstore.max_marginal_relevance_search( sub_docs = self.vectorstore.max_marginal_relevance_search(
query, **self.search_kwargs query, **self.search_kwargs
) )
elif self.search_type == SearchType.similarity_score_threshold:
sub_docs_and_similarities = (
self.vectorstore.similarity_search_with_relevance_scores(
query, **self.search_kwargs
)
)
sub_docs = [sub_doc for sub_doc, _ in sub_docs_and_similarities]
else: else:
sub_docs = self.vectorstore.similarity_search(query, **self.search_kwargs) sub_docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
@ -89,6 +98,13 @@ class MultiVectorRetriever(BaseRetriever):
sub_docs = await self.vectorstore.amax_marginal_relevance_search( sub_docs = await self.vectorstore.amax_marginal_relevance_search(
query, **self.search_kwargs query, **self.search_kwargs
) )
elif self.search_type == SearchType.similarity_score_threshold:
sub_docs_and_similarities = (
await self.vectorstore.asimilarity_search_with_relevance_scores(
query, **self.search_kwargs
)
)
sub_docs = [sub_doc for sub_doc, _ in sub_docs_and_similarities]
else: else:
sub_docs = await self.vectorstore.asimilarity_search( sub_docs = await self.vectorstore.asimilarity_search(
query, **self.search_kwargs query, **self.search_kwargs

View File

@ -1,13 +1,20 @@
from typing import Any, List from typing import Any, Callable, List, Tuple
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain.retrievers.multi_vector import MultiVectorRetriever from langchain.retrievers.multi_vector import MultiVectorRetriever, SearchType
from langchain.storage import InMemoryStore from langchain.storage import InMemoryStore
from tests.unit_tests.indexes.test_indexing import InMemoryVectorStore from tests.unit_tests.indexes.test_indexing import InMemoryVectorStore
class InMemoryVectorstoreWithSearch(InMemoryVectorStore): class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
@staticmethod
def _identity_fn(score: float) -> float:
return score
def _select_relevance_score_fn(self) -> Callable[[float], float]:
return self._identity_fn
def similarity_search( def similarity_search(
self, query: str, k: int = 4, **kwargs: Any self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]: ) -> List[Document]:
@ -16,6 +23,14 @@ class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
return [] return []
return [res] return [res]
def similarity_search_with_score(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Tuple[Document, float]]:
res = self.store.get(query)
if res is None:
return []
return [(res, 0.8)]
def test_multi_vector_retriever_initialization() -> None: def test_multi_vector_retriever_initialization() -> None:
vectorstore = InMemoryVectorstoreWithSearch() vectorstore = InMemoryVectorstoreWithSearch()
@ -41,3 +56,65 @@ async def test_multi_vector_retriever_initialization_async() -> None:
results = await retriever.ainvoke("1") results = await retriever.ainvoke("1")
assert len(results) > 0 assert len(results) > 0
assert results[0].page_content == "test document" assert results[0].page_content == "test document"
def test_multi_vector_retriever_similarity_search_with_score() -> None:
documents = [Document(page_content="test document", metadata={"doc_id": "1"})]
vectorstore = InMemoryVectorstoreWithSearch()
vectorstore.add_documents(documents, ids=["1"])
# score_threshold = 0.5
retriever = MultiVectorRetriever( # type: ignore[call-arg]
vectorstore=vectorstore,
docstore=InMemoryStore(),
doc_id="doc_id",
search_kwargs={"score_threshold": 0.5},
search_type=SearchType.similarity_score_threshold,
)
retriever.docstore.mset(list(zip(["1"], documents)))
results = retriever.invoke("1")
assert len(results) == 1
assert results[0].page_content == "test document"
# score_threshold = 0.9
retriever = MultiVectorRetriever( # type: ignore[call-arg]
vectorstore=vectorstore,
docstore=InMemoryStore(),
doc_id="doc_id",
search_kwargs={"score_threshold": 0.9},
search_type=SearchType.similarity_score_threshold,
)
retriever.docstore.mset(list(zip(["1"], documents)))
results = retriever.invoke("1")
assert len(results) == 0
async def test_multi_vector_retriever_similarity_search_with_score_async() -> None:
documents = [Document(page_content="test document", metadata={"doc_id": "1"})]
vectorstore = InMemoryVectorstoreWithSearch()
await vectorstore.aadd_documents(documents, ids=["1"])
# score_threshold = 0.5
retriever = MultiVectorRetriever( # type: ignore[call-arg]
vectorstore=vectorstore,
docstore=InMemoryStore(),
doc_id="doc_id",
search_kwargs={"score_threshold": 0.5},
search_type=SearchType.similarity_score_threshold,
)
await retriever.docstore.amset(list(zip(["1"], documents)))
results = retriever.invoke("1")
assert len(results) == 1
assert results[0].page_content == "test document"
# score_threshold = 0.9
retriever = MultiVectorRetriever( # type: ignore[call-arg]
vectorstore=vectorstore,
docstore=InMemoryStore(),
doc_id="doc_id",
search_kwargs={"score_threshold": 0.9},
search_type=SearchType.similarity_score_threshold,
)
await retriever.docstore.amset(list(zip(["1"], documents)))
results = retriever.invoke("1")
assert len(results) == 0