community[patch]: AzureSearchVectorStoreRetriever Fixed to account for search_kwargs (#21572)

- **Description:** Fixed `AzureSearchVectorStoreRetriever` to account
for search_kwargs. More explanation is in the mentioned issue.
- **Issue:** #21492

---------

Co-authored-by: MAC <mac@MACs-MacBook-Pro.local>
Co-authored-by: Massimiliano Pronesti <massimiliano.pronesti@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Mohammad Mohtashim 2024-05-23 02:46:41 +05:00 committed by GitHub
parent 45351d1bc6
commit 16617dd239
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -848,7 +848,6 @@ class AzureSearch(VectorStore):
"semantic_hybrid". "semantic_hybrid".
search_kwargs (Optional[Dict]): Keyword arguments to pass to the search_kwargs (Optional[Dict]): Keyword arguments to pass to the
search function. Can include things like: search function. Can include things like:
k: Amount of documents to return (Default: 4)
score_threshold: Minimum relevance threshold score_threshold: Minimum relevance threshold
for similarity_score_threshold for similarity_score_threshold
fetch_k: Amount of documents to pass to MMR algorithm (Default: 20) fetch_k: Amount of documents to pass to MMR algorithm (Default: 20)
@ -875,6 +874,16 @@ class AzureSearchVectorStoreRetriever(BaseRetriever):
or "semantic_hybrid_score_threshold".""" or "semantic_hybrid_score_threshold"."""
k: int = 4 k: int = 4
"""Number of documents to return.""" """Number of documents to return."""
search_kwargs: dict = {}
"""Search params.
score_threshold: Minimum relevance threshold
for similarity_score_threshold
fetch_k: Amount of documents to pass to MMR algorithm (Default: 20)
lambda_mult: Diversity of results returned by MMR;
1 for minimum diversity and 0 for maximum. (Default: 0.5)
filter: Filter by document metadata
"""
allowed_search_types: ClassVar[Collection[str]] = ( allowed_search_types: ClassVar[Collection[str]] = (
"similarity", "similarity",
"similarity_score_threshold", "similarity_score_threshold",
@ -907,31 +916,33 @@ class AzureSearchVectorStoreRetriever(BaseRetriever):
run_manager: CallbackManagerForRetrieverRun, run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any, **kwargs: Any,
) -> List[Document]: ) -> List[Document]:
params = {**self.search_kwargs, **kwargs}
if self.search_type == "similarity": if self.search_type == "similarity":
docs = self.vectorstore.vector_search(query, k=self.k, **kwargs) docs = self.vectorstore.vector_search(query, k=self.k, **params)
elif self.search_type == "similarity_score_threshold": elif self.search_type == "similarity_score_threshold":
docs = [ docs = [
doc doc
for doc, _ in self.vectorstore.similarity_search_with_relevance_scores( for doc, _ in self.vectorstore.similarity_search_with_relevance_scores(
query, k=self.k, **kwargs query, k=self.k, **params
) )
] ]
elif self.search_type == "hybrid": elif self.search_type == "hybrid":
docs = self.vectorstore.hybrid_search(query, k=self.k, **kwargs) docs = self.vectorstore.hybrid_search(query, k=self.k, **params)
elif self.search_type == "hybrid_score_threshold": elif self.search_type == "hybrid_score_threshold":
docs = [ docs = [
doc doc
for doc, _ in self.vectorstore.hybrid_search_with_relevance_scores( for doc, _ in self.vectorstore.hybrid_search_with_relevance_scores(
query, k=self.k, **kwargs query, k=self.k, **params
) )
] ]
elif self.search_type == "semantic_hybrid": elif self.search_type == "semantic_hybrid":
docs = self.vectorstore.semantic_hybrid_search(query, k=self.k, **kwargs) docs = self.vectorstore.semantic_hybrid_search(query, k=self.k, **params)
elif self.search_type == "semantic_hybrid_score_threshold": elif self.search_type == "semantic_hybrid_score_threshold":
docs = [ docs = [
doc doc
for doc, _ in self.vectorstore.semantic_hybrid_search_with_score( for doc, _ in self.vectorstore.semantic_hybrid_search_with_score(
query, k=self.k, **kwargs query, k=self.k, **params
) )
] ]
else: else: