mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-30 10:23:30 +00:00
community[patch]: Add missing async similarity_distance_threshold handling in RedisVectorStoreRetriever (#16359)
Add missing async similarity_distance_threshold handling in RedisVectorStoreRetriever - **Description:** added method `_aget_relevant_documents` to `RedisVectorStoreRetriever` that overrides parent method to add support of `similarity_distance_threshold` in async mode (as for sync mode) - **Issue:** #16099 - **Dependencies:** N/A - **Twitter handle:** N/A
This commit is contained in:
parent
7c6a2a8384
commit
e9d3527b79
@ -23,7 +23,10 @@ from typing import (
|
||||
import numpy as np
|
||||
import yaml
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
@ -1464,6 +1467,37 @@ class RedisVectorStoreRetriever(VectorStoreRetriever):
|
||||
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
||||
return docs
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
if self.search_type == "similarity":
|
||||
docs = await self.vectorstore.asimilarity_search(
|
||||
query, **self.search_kwargs
|
||||
)
|
||||
elif self.search_type == "similarity_distance_threshold":
|
||||
if self.search_kwargs["distance_threshold"] is None:
|
||||
raise ValueError(
|
||||
"distance_threshold must be provided for "
|
||||
+ "similarity_distance_threshold retriever"
|
||||
)
|
||||
docs = await self.vectorstore.asimilarity_search(
|
||||
query, **self.search_kwargs
|
||||
)
|
||||
elif self.search_type == "similarity_score_threshold":
|
||||
docs_and_similarities = (
|
||||
await self.vectorstore.asimilarity_search_with_relevance_scores(
|
||||
query, **self.search_kwargs
|
||||
)
|
||||
)
|
||||
docs = [doc for doc, _ in docs_and_similarities]
|
||||
elif self.search_type == "mmr":
|
||||
docs = await self.vectorstore.amax_marginal_relevance_search(
|
||||
query, **self.search_kwargs
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
||||
return docs
|
||||
|
||||
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
|
||||
"""Add documents to vectorstore."""
|
||||
return self.vectorstore.add_documents(documents, **kwargs)
|
||||
|
Loading…
Reference in New Issue
Block a user