mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-15 23:57:21 +00:00
retrievers: MultiVectorRetriever similarity_score_threshold search type (#23539)
Description: support MultiVectorRetriever similarity_score_threshold search type. Issue: #23387 #19404 --------- Co-authored-by: gongwn1 <gongwn1@lenovo.com>
This commit is contained in:
parent
20151384d7
commit
a001037319
@ -19,6 +19,8 @@ class SearchType(str, Enum):
|
||||
|
||||
similarity = "similarity"
|
||||
"""Similarity search."""
|
||||
similarity_score_threshold = "similarity_score_threshold"
|
||||
"""Similarity search with a score threshold."""
|
||||
mmr = "mmr"
|
||||
"""Maximal Marginal Relevance reranking of similarity search."""
|
||||
|
||||
@ -64,6 +66,13 @@ class MultiVectorRetriever(BaseRetriever):
|
||||
sub_docs = self.vectorstore.max_marginal_relevance_search(
|
||||
query, **self.search_kwargs
|
||||
)
|
||||
elif self.search_type == SearchType.similarity_score_threshold:
|
||||
sub_docs_and_similarities = (
|
||||
self.vectorstore.similarity_search_with_relevance_scores(
|
||||
query, **self.search_kwargs
|
||||
)
|
||||
)
|
||||
sub_docs = [sub_doc for sub_doc, _ in sub_docs_and_similarities]
|
||||
else:
|
||||
sub_docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
|
||||
|
||||
@ -89,6 +98,13 @@ class MultiVectorRetriever(BaseRetriever):
|
||||
sub_docs = await self.vectorstore.amax_marginal_relevance_search(
|
||||
query, **self.search_kwargs
|
||||
)
|
||||
elif self.search_type == SearchType.similarity_score_threshold:
|
||||
sub_docs_and_similarities = (
|
||||
await self.vectorstore.asimilarity_search_with_relevance_scores(
|
||||
query, **self.search_kwargs
|
||||
)
|
||||
)
|
||||
sub_docs = [sub_doc for sub_doc, _ in sub_docs_and_similarities]
|
||||
else:
|
||||
sub_docs = await self.vectorstore.asimilarity_search(
|
||||
query, **self.search_kwargs
|
||||
|
@ -1,13 +1,20 @@
|
||||
from typing import Any, List
|
||||
from typing import Any, Callable, List, Tuple
|
||||
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langchain.retrievers.multi_vector import MultiVectorRetriever
|
||||
from langchain.retrievers.multi_vector import MultiVectorRetriever, SearchType
|
||||
from langchain.storage import InMemoryStore
|
||||
from tests.unit_tests.indexes.test_indexing import InMemoryVectorStore
|
||||
|
||||
|
||||
class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
|
||||
@staticmethod
|
||||
def _identity_fn(score: float) -> float:
|
||||
return score
|
||||
|
||||
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
||||
return self._identity_fn
|
||||
|
||||
def similarity_search(
|
||||
self, query: str, k: int = 4, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
@ -16,6 +23,14 @@ class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
|
||||
return []
|
||||
return [res]
|
||||
|
||||
def similarity_search_with_score(
|
||||
self, query: str, k: int = 4, **kwargs: Any
|
||||
) -> List[Tuple[Document, float]]:
|
||||
res = self.store.get(query)
|
||||
if res is None:
|
||||
return []
|
||||
return [(res, 0.8)]
|
||||
|
||||
|
||||
def test_multi_vector_retriever_initialization() -> None:
|
||||
vectorstore = InMemoryVectorstoreWithSearch()
|
||||
@ -41,3 +56,65 @@ async def test_multi_vector_retriever_initialization_async() -> None:
|
||||
results = await retriever.ainvoke("1")
|
||||
assert len(results) > 0
|
||||
assert results[0].page_content == "test document"
|
||||
|
||||
|
||||
def test_multi_vector_retriever_similarity_search_with_score() -> None:
|
||||
documents = [Document(page_content="test document", metadata={"doc_id": "1"})]
|
||||
vectorstore = InMemoryVectorstoreWithSearch()
|
||||
vectorstore.add_documents(documents, ids=["1"])
|
||||
|
||||
# score_threshold = 0.5
|
||||
retriever = MultiVectorRetriever( # type: ignore[call-arg]
|
||||
vectorstore=vectorstore,
|
||||
docstore=InMemoryStore(),
|
||||
doc_id="doc_id",
|
||||
search_kwargs={"score_threshold": 0.5},
|
||||
search_type=SearchType.similarity_score_threshold,
|
||||
)
|
||||
retriever.docstore.mset(list(zip(["1"], documents)))
|
||||
results = retriever.invoke("1")
|
||||
assert len(results) == 1
|
||||
assert results[0].page_content == "test document"
|
||||
|
||||
# score_threshold = 0.9
|
||||
retriever = MultiVectorRetriever( # type: ignore[call-arg]
|
||||
vectorstore=vectorstore,
|
||||
docstore=InMemoryStore(),
|
||||
doc_id="doc_id",
|
||||
search_kwargs={"score_threshold": 0.9},
|
||||
search_type=SearchType.similarity_score_threshold,
|
||||
)
|
||||
retriever.docstore.mset(list(zip(["1"], documents)))
|
||||
results = retriever.invoke("1")
|
||||
assert len(results) == 0
|
||||
|
||||
|
||||
async def test_multi_vector_retriever_similarity_search_with_score_async() -> None:
|
||||
documents = [Document(page_content="test document", metadata={"doc_id": "1"})]
|
||||
vectorstore = InMemoryVectorstoreWithSearch()
|
||||
await vectorstore.aadd_documents(documents, ids=["1"])
|
||||
|
||||
# score_threshold = 0.5
|
||||
retriever = MultiVectorRetriever( # type: ignore[call-arg]
|
||||
vectorstore=vectorstore,
|
||||
docstore=InMemoryStore(),
|
||||
doc_id="doc_id",
|
||||
search_kwargs={"score_threshold": 0.5},
|
||||
search_type=SearchType.similarity_score_threshold,
|
||||
)
|
||||
await retriever.docstore.amset(list(zip(["1"], documents)))
|
||||
results = retriever.invoke("1")
|
||||
assert len(results) == 1
|
||||
assert results[0].page_content == "test document"
|
||||
|
||||
# score_threshold = 0.9
|
||||
retriever = MultiVectorRetriever( # type: ignore[call-arg]
|
||||
vectorstore=vectorstore,
|
||||
docstore=InMemoryStore(),
|
||||
doc_id="doc_id",
|
||||
search_kwargs={"score_threshold": 0.9},
|
||||
search_type=SearchType.similarity_score_threshold,
|
||||
)
|
||||
await retriever.docstore.amset(list(zip(["1"], documents)))
|
||||
results = retriever.invoke("1")
|
||||
assert len(results) == 0
|
||||
|
Loading…
Reference in New Issue
Block a user