feat(community): support semantic hybrid score threshold in Azure AI Search (#21527)

Support semantic hybrid search with a score threshold -- similar to what
we do for similarity search and for hybrid search (#20907).
This commit is contained in:
Massimiliano Pronesti 2024-05-16 21:54:32 +02:00 committed by GitHub
parent 5e445a7e4e
commit 0c0db7c5db
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -13,6 +13,7 @@ from typing import (
Dict, Dict,
Iterable, Iterable,
List, List,
Literal,
Optional, Optional,
Tuple, Tuple,
Type, Type,
@ -567,7 +568,11 @@ class AzureSearch(VectorStore):
return [doc for doc, _, _ in docs_and_scores] return [doc for doc, _, _ in docs_and_scores]
def semantic_hybrid_search_with_score( def semantic_hybrid_search_with_score(
self, query: str, k: int = 4, **kwargs: Any self,
query: str,
k: int = 4,
score_type: Literal["score", "reranker_score"] = "score",
**kwargs: Any,
) -> List[Tuple[Document, float]]: ) -> List[Tuple[Document, float]]:
""" """
Returns the most similar indexed documents to the query text. Returns the most similar indexed documents to the query text.
@ -575,14 +580,29 @@ class AzureSearch(VectorStore):
Args: Args:
query (str): The query text for which to find similar documents. query (str): The query text for which to find similar documents.
k (int): The number of documents to return. Default is 4. k (int): The number of documents to return. Default is 4.
score_type: Must either be "score" or "reranker_score".
Defaulted to "score".
Returns: Returns:
List[Document]: A list of documents that are most similar to the query text. List[Tuple[Document, float]]: A list of documents and their
corresponding scores.
""" """
score_threshold = kwargs.pop("score_threshold", None)
docs_and_scores = self.semantic_hybrid_search_with_score_and_rerank( docs_and_scores = self.semantic_hybrid_search_with_score_and_rerank(
query, k=k, filters=kwargs.get("filters", None) query, k=k, filters=kwargs.get("filters", None)
) )
return [(doc, score) for doc, score, _ in docs_and_scores] if score_type == "score":
return [
(doc, score)
for doc, score, _ in docs_and_scores
if score_threshold is None or score >= score_threshold
]
elif score_type == "reranker_score":
return [
(doc, reranker_score)
for doc, _, reranker_score in docs_and_scores
if score_threshold is None or reranker_score >= score_threshold
]
def semantic_hybrid_search_with_score_and_rerank( def semantic_hybrid_search_with_score_and_rerank(
self, query: str, k: int = 4, filters: Optional[str] = None self, query: str, k: int = 4, filters: Optional[str] = None
@ -716,7 +736,8 @@ class AzureSearchVectorStoreRetriever(BaseRetriever):
"""Azure Search instance used to find similar documents.""" """Azure Search instance used to find similar documents."""
search_type: str = "hybrid" search_type: str = "hybrid"
"""Type of search to perform. Options are "similarity", "hybrid", """Type of search to perform. Options are "similarity", "hybrid",
"semantic_hybrid", "similarity_score_threshold", "hybrid_score_threshold".""" "semantic_hybrid", "similarity_score_threshold", "hybrid_score_threshold",
or "semantic_hybrid_score_threshold"."""
k: int = 4 k: int = 4
"""Number of documents to return.""" """Number of documents to return."""
allowed_search_types: ClassVar[Collection[str]] = ( allowed_search_types: ClassVar[Collection[str]] = (
@ -725,6 +746,7 @@ class AzureSearchVectorStoreRetriever(BaseRetriever):
"hybrid", "hybrid",
"hybrid_score_threshold", "hybrid_score_threshold",
"semantic_hybrid", "semantic_hybrid",
"semantic_hybrid_score_threshold",
) )
class Config: class Config:
@ -770,6 +792,13 @@ class AzureSearchVectorStoreRetriever(BaseRetriever):
] ]
elif self.search_type == "semantic_hybrid": elif self.search_type == "semantic_hybrid":
docs = self.vectorstore.semantic_hybrid_search(query, k=self.k, **kwargs) docs = self.vectorstore.semantic_hybrid_search(query, k=self.k, **kwargs)
elif self.search_type == "semantic_hybrid_score_threshold":
docs = [
doc
for doc, _ in self.vectorstore.semantic_hybrid_search_with_score(
query, k=self.k, **kwargs
)
]
else: else:
raise ValueError(f"search_type of {self.search_type} not allowed.") raise ValueError(f"search_type of {self.search_type} not allowed.")
return docs return docs