From de7996c2ca71cd6d59c752251d2042a379376351 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Mon, 16 Dec 2024 10:57:57 -0800 Subject: [PATCH] core: add kwargs support to VectorStore (#25934) has been missing the passthrough until now --------- Co-authored-by: Erick Friis --- .../graph_vectorstores/base.py | 8 +++-- .../vectorstores/redis/base.py | 36 +++++++++---------- .../vectorstores/vectara.py | 4 +-- .../chains/test_pebblo_retrieval.py | 10 ++++-- libs/core/langchain_core/vectorstores/base.py | 30 +++++++++------- 5 files changed, 50 insertions(+), 38 deletions(-) diff --git a/libs/community/langchain_community/graph_vectorstores/base.py b/libs/community/langchain_community/graph_vectorstores/base.py index 3c0d81b915f..e779af5f141 100644 --- a/libs/community/langchain_community/graph_vectorstores/base.py +++ b/libs/community/langchain_community/graph_vectorstores/base.py @@ -855,7 +855,7 @@ class GraphVectorStoreRetriever(VectorStoreRetriever): return cast(GraphVectorStore, self.vectorstore) def _get_relevant_documents( - self, query: str, *, run_manager: CallbackManagerForRetrieverRun + self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any ) -> list[Document]: if self.search_type == "traversal": return list( @@ -869,7 +869,11 @@ class GraphVectorStoreRetriever(VectorStoreRetriever): return super()._get_relevant_documents(query, run_manager=run_manager) async def _aget_relevant_documents( - self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun + self, + query: str, + *, + run_manager: AsyncCallbackManagerForRetrieverRun, + **kwargs: Any, ) -> list[Document]: if self.search_type == "traversal": return [ diff --git a/libs/community/langchain_community/vectorstores/redis/base.py b/libs/community/langchain_community/vectorstores/redis/base.py index e18660a6e32..a286ddfb35c 100644 --- a/libs/community/langchain_community/vectorstores/redis/base.py +++ b/libs/community/langchain_community/vectorstores/redis/base.py @@ -1459,59 +1459,59 @@ class RedisVectorStoreRetriever(VectorStoreRetriever): # type: ignore[override] ) def _get_relevant_documents( - self, query: str, *, run_manager: CallbackManagerForRetrieverRun + self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any ) -> List[Document]: + _kwargs = self.search_kwargs | kwargs if self.search_type == "similarity": - docs = self.vectorstore.similarity_search(query, **self.search_kwargs) + docs = self.vectorstore.similarity_search(query, **_kwargs) elif self.search_type == "similarity_distance_threshold": - if self.search_kwargs["distance_threshold"] is None: + if _kwargs["distance_threshold"] is None: raise ValueError( "distance_threshold must be provided for " + "similarity_distance_threshold retriever" ) - docs = self.vectorstore.similarity_search(query, **self.search_kwargs) + docs = self.vectorstore.similarity_search(query, **_kwargs) elif self.search_type == "similarity_score_threshold": docs_and_similarities = ( self.vectorstore.similarity_search_with_relevance_scores( - query, **self.search_kwargs + query, **_kwargs ) ) docs = [doc for doc, _ in docs_and_similarities] elif self.search_type == "mmr": - docs = self.vectorstore.max_marginal_relevance_search( - query, **self.search_kwargs - ) + docs = self.vectorstore.max_marginal_relevance_search(query, **_kwargs) else: raise ValueError(f"search_type of {self.search_type} not allowed.") return docs async def _aget_relevant_documents( - self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun + self, + query: str, + *, + run_manager: AsyncCallbackManagerForRetrieverRun, + **kwargs: Any, ) -> List[Document]: + _kwargs = self.search_kwargs | kwargs if self.search_type == "similarity": - docs = await self.vectorstore.asimilarity_search( - query, **self.search_kwargs - ) + docs = await self.vectorstore.asimilarity_search(query, **_kwargs) elif self.search_type == "similarity_distance_threshold": - if self.search_kwargs["distance_threshold"] is None: + if _kwargs["distance_threshold"] is None: raise ValueError( "distance_threshold must be provided for " + "similarity_distance_threshold retriever" ) - docs = await self.vectorstore.asimilarity_search( - query, **self.search_kwargs - ) + docs = await self.vectorstore.asimilarity_search(query, **_kwargs) elif self.search_type == "similarity_score_threshold": docs_and_similarities = ( await self.vectorstore.asimilarity_search_with_relevance_scores( - query, **self.search_kwargs + query, **_kwargs ) ) docs = [doc for doc, _ in docs_and_similarities] elif self.search_type == "mmr": docs = await self.vectorstore.amax_marginal_relevance_search( - query, **self.search_kwargs + query, **_kwargs ) else: raise ValueError(f"search_type of {self.search_type} not allowed.") diff --git a/libs/community/langchain_community/vectorstores/vectara.py b/libs/community/langchain_community/vectorstores/vectara.py index 7e77c6e6905..be3b2300b7a 100644 --- a/libs/community/langchain_community/vectorstores/vectara.py +++ b/libs/community/langchain_community/vectorstores/vectara.py @@ -745,9 +745,9 @@ class VectaraRetriever(VectorStoreRetriever): # type: ignore[override] ) def _get_relevant_documents( - self, query: str, *, run_manager: CallbackManagerForRetrieverRun + self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any ) -> List[Document]: - docs_and_scores = self.vectorstore.vectara_query(query, self.config) + docs_and_scores = self.vectorstore.vectara_query(query, self.config, **kwargs) return [doc for doc, _ in docs_and_scores] def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]: diff --git a/libs/community/tests/unit_tests/chains/test_pebblo_retrieval.py b/libs/community/tests/unit_tests/chains/test_pebblo_retrieval.py index a2fb1dbd009..8928ab87b56 100644 --- a/libs/community/tests/unit_tests/chains/test_pebblo_retrieval.py +++ b/libs/community/tests/unit_tests/chains/test_pebblo_retrieval.py @@ -2,7 +2,7 @@ Unit tests for the PebbloRetrievalQA chain """ -from typing import List +from typing import Any, List from unittest.mock import Mock import pytest @@ -35,12 +35,16 @@ class FakeRetriever(VectorStoreRetriever): vectorstore: VectorStore = Mock() def _get_relevant_documents( - self, query: str, *, run_manager: CallbackManagerForRetrieverRun + self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any ) -> List[Document]: return [Document(page_content=query)] async def _aget_relevant_documents( - self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun + self, + query: str, + *, + run_manager: AsyncCallbackManagerForRetrieverRun, + **kwargs: Any, ) -> List[Document]: return [Document(page_content=query)] diff --git a/libs/core/langchain_core/vectorstores/base.py b/libs/core/langchain_core/vectorstores/base.py index c87da224acd..e4e861e76b9 100644 --- a/libs/core/langchain_core/vectorstores/base.py +++ b/libs/core/langchain_core/vectorstores/base.py @@ -1057,7 +1057,9 @@ class VectorStoreRetriever(BaseRetriever): def _get_ls_params(self, **kwargs: Any) -> LangSmithRetrieverParams: """Get standard params for tracing.""" - ls_params = super()._get_ls_params(**kwargs) + _kwargs = self.search_kwargs | kwargs + + ls_params = super()._get_ls_params(**_kwargs) ls_params["ls_vector_store_provider"] = self.vectorstore.__class__.__name__ if self.vectorstore.embeddings: @@ -1074,43 +1076,45 @@ class VectorStoreRetriever(BaseRetriever): return ls_params def _get_relevant_documents( - self, query: str, *, run_manager: CallbackManagerForRetrieverRun + self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any ) -> list[Document]: + _kwargs = self.search_kwargs | kwargs if self.search_type == "similarity": - docs = self.vectorstore.similarity_search(query, **self.search_kwargs) + docs = self.vectorstore.similarity_search(query, **_kwargs) elif self.search_type == "similarity_score_threshold": docs_and_similarities = ( self.vectorstore.similarity_search_with_relevance_scores( - query, **self.search_kwargs + query, **_kwargs ) ) docs = [doc for doc, _ in docs_and_similarities] elif self.search_type == "mmr": - docs = self.vectorstore.max_marginal_relevance_search( - query, **self.search_kwargs - ) + docs = self.vectorstore.max_marginal_relevance_search(query, **_kwargs) else: msg = f"search_type of {self.search_type} not allowed." raise ValueError(msg) return docs async def _aget_relevant_documents( - self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun + self, + query: str, + *, + run_manager: AsyncCallbackManagerForRetrieverRun, + **kwargs: Any, ) -> list[Document]: + _kwargs = self.search_kwargs | kwargs if self.search_type == "similarity": - docs = await self.vectorstore.asimilarity_search( - query, **self.search_kwargs - ) + docs = await self.vectorstore.asimilarity_search(query, **_kwargs) elif self.search_type == "similarity_score_threshold": docs_and_similarities = ( await self.vectorstore.asimilarity_search_with_relevance_scores( - query, **self.search_kwargs + query, **_kwargs ) ) docs = [doc for doc, _ in docs_and_similarities] elif self.search_type == "mmr": docs = await self.vectorstore.amax_marginal_relevance_search( - query, **self.search_kwargs + query, **_kwargs ) else: msg = f"search_type of {self.search_type} not allowed."