mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-30 10:23:30 +00:00
community[patch]: Add missing async similarity_distance_threshold handling in RedisVectorStoreRetriever (#16359)
Add missing async similarity_distance_threshold handling in RedisVectorStoreRetriever - **Description:** added method `_aget_relevant_documents` to `RedisVectorStoreRetriever` that overrides parent method to add support of `similarity_distance_threshold` in async mode (as for sync mode) - **Issue:** #16099 - **Dependencies:** N/A - **Twitter handle:** N/A
This commit is contained in:
parent
7c6a2a8384
commit
e9d3527b79
@ -23,7 +23,10 @@ from typing import (
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import yaml
|
import yaml
|
||||||
from langchain_core._api import deprecated
|
from langchain_core._api import deprecated
|
||||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
from langchain_core.callbacks import (
|
||||||
|
AsyncCallbackManagerForRetrieverRun,
|
||||||
|
CallbackManagerForRetrieverRun,
|
||||||
|
)
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.utils import get_from_dict_or_env
|
from langchain_core.utils import get_from_dict_or_env
|
||||||
@ -1464,6 +1467,37 @@ class RedisVectorStoreRetriever(VectorStoreRetriever):
|
|||||||
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
|
async def _aget_relevant_documents(
|
||||||
|
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||||
|
) -> List[Document]:
|
||||||
|
if self.search_type == "similarity":
|
||||||
|
docs = await self.vectorstore.asimilarity_search(
|
||||||
|
query, **self.search_kwargs
|
||||||
|
)
|
||||||
|
elif self.search_type == "similarity_distance_threshold":
|
||||||
|
if self.search_kwargs["distance_threshold"] is None:
|
||||||
|
raise ValueError(
|
||||||
|
"distance_threshold must be provided for "
|
||||||
|
+ "similarity_distance_threshold retriever"
|
||||||
|
)
|
||||||
|
docs = await self.vectorstore.asimilarity_search(
|
||||||
|
query, **self.search_kwargs
|
||||||
|
)
|
||||||
|
elif self.search_type == "similarity_score_threshold":
|
||||||
|
docs_and_similarities = (
|
||||||
|
await self.vectorstore.asimilarity_search_with_relevance_scores(
|
||||||
|
query, **self.search_kwargs
|
||||||
|
)
|
||||||
|
)
|
||||||
|
docs = [doc for doc, _ in docs_and_similarities]
|
||||||
|
elif self.search_type == "mmr":
|
||||||
|
docs = await self.vectorstore.amax_marginal_relevance_search(
|
||||||
|
query, **self.search_kwargs
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
||||||
|
return docs
|
||||||
|
|
||||||
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
|
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
|
||||||
"""Add documents to vectorstore."""
|
"""Add documents to vectorstore."""
|
||||||
return self.vectorstore.add_documents(documents, **kwargs)
|
return self.vectorstore.add_documents(documents, **kwargs)
|
||||||
|
Loading…
Reference in New Issue
Block a user