core: add kwargs support to VectorStore (#25934)

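The vector store retrievers have been missing the kwargs passthrough until now: per-call keyword arguments were not merged with `search_kwargs` or forwarded to the underlying vector store search methods.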

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Harrison Chase 2024-12-16 10:57:57 -08:00 committed by GitHub
parent 87c50f99e5
commit de7996c2ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 50 additions and 38 deletions

View File

@ -855,7 +855,7 @@ class GraphVectorStoreRetriever(VectorStoreRetriever):
return cast(GraphVectorStore, self.vectorstore) return cast(GraphVectorStore, self.vectorstore)
def _get_relevant_documents( def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
) -> list[Document]: ) -> list[Document]:
if self.search_type == "traversal": if self.search_type == "traversal":
return list( return list(
@ -869,7 +869,11 @@ class GraphVectorStoreRetriever(VectorStoreRetriever):
return super()._get_relevant_documents(query, run_manager=run_manager) return super()._get_relevant_documents(query, run_manager=run_manager)
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> list[Document]: ) -> list[Document]:
if self.search_type == "traversal": if self.search_type == "traversal":
return [ return [

View File

@ -1459,59 +1459,59 @@ class RedisVectorStoreRetriever(VectorStoreRetriever): # type: ignore[override]
) )
def _get_relevant_documents( def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
) -> List[Document]: ) -> List[Document]:
_kwargs = self.search_kwargs | kwargs
if self.search_type == "similarity": if self.search_type == "similarity":
docs = self.vectorstore.similarity_search(query, **self.search_kwargs) docs = self.vectorstore.similarity_search(query, **_kwargs)
elif self.search_type == "similarity_distance_threshold": elif self.search_type == "similarity_distance_threshold":
if self.search_kwargs["distance_threshold"] is None: if _kwargs["distance_threshold"] is None:
raise ValueError( raise ValueError(
"distance_threshold must be provided for " "distance_threshold must be provided for "
+ "similarity_distance_threshold retriever" + "similarity_distance_threshold retriever"
) )
docs = self.vectorstore.similarity_search(query, **self.search_kwargs) docs = self.vectorstore.similarity_search(query, **_kwargs)
elif self.search_type == "similarity_score_threshold": elif self.search_type == "similarity_score_threshold":
docs_and_similarities = ( docs_and_similarities = (
self.vectorstore.similarity_search_with_relevance_scores( self.vectorstore.similarity_search_with_relevance_scores(
query, **self.search_kwargs query, **_kwargs
) )
) )
docs = [doc for doc, _ in docs_and_similarities] docs = [doc for doc, _ in docs_and_similarities]
elif self.search_type == "mmr": elif self.search_type == "mmr":
docs = self.vectorstore.max_marginal_relevance_search( docs = self.vectorstore.max_marginal_relevance_search(query, **_kwargs)
query, **self.search_kwargs
)
else: else:
raise ValueError(f"search_type of {self.search_type} not allowed.") raise ValueError(f"search_type of {self.search_type} not allowed.")
return docs return docs
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
_kwargs = self.search_kwargs | kwargs
if self.search_type == "similarity": if self.search_type == "similarity":
docs = await self.vectorstore.asimilarity_search( docs = await self.vectorstore.asimilarity_search(query, **_kwargs)
query, **self.search_kwargs
)
elif self.search_type == "similarity_distance_threshold": elif self.search_type == "similarity_distance_threshold":
if self.search_kwargs["distance_threshold"] is None: if _kwargs["distance_threshold"] is None:
raise ValueError( raise ValueError(
"distance_threshold must be provided for " "distance_threshold must be provided for "
+ "similarity_distance_threshold retriever" + "similarity_distance_threshold retriever"
) )
docs = await self.vectorstore.asimilarity_search( docs = await self.vectorstore.asimilarity_search(query, **_kwargs)
query, **self.search_kwargs
)
elif self.search_type == "similarity_score_threshold": elif self.search_type == "similarity_score_threshold":
docs_and_similarities = ( docs_and_similarities = (
await self.vectorstore.asimilarity_search_with_relevance_scores( await self.vectorstore.asimilarity_search_with_relevance_scores(
query, **self.search_kwargs query, **_kwargs
) )
) )
docs = [doc for doc, _ in docs_and_similarities] docs = [doc for doc, _ in docs_and_similarities]
elif self.search_type == "mmr": elif self.search_type == "mmr":
docs = await self.vectorstore.amax_marginal_relevance_search( docs = await self.vectorstore.amax_marginal_relevance_search(
query, **self.search_kwargs query, **_kwargs
) )
else: else:
raise ValueError(f"search_type of {self.search_type} not allowed.") raise ValueError(f"search_type of {self.search_type} not allowed.")

View File

@ -745,9 +745,9 @@ class VectaraRetriever(VectorStoreRetriever): # type: ignore[override]
) )
def _get_relevant_documents( def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
) -> List[Document]: ) -> List[Document]:
docs_and_scores = self.vectorstore.vectara_query(query, self.config) docs_and_scores = self.vectorstore.vectara_query(query, self.config, **kwargs)
return [doc for doc, _ in docs_and_scores] return [doc for doc, _ in docs_and_scores]
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]: def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:

View File

@ -2,7 +2,7 @@
Unit tests for the PebbloRetrievalQA chain Unit tests for the PebbloRetrievalQA chain
""" """
from typing import List from typing import Any, List
from unittest.mock import Mock from unittest.mock import Mock
import pytest import pytest
@ -35,12 +35,16 @@ class FakeRetriever(VectorStoreRetriever):
vectorstore: VectorStore = Mock() vectorstore: VectorStore = Mock()
def _get_relevant_documents( def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
) -> List[Document]: ) -> List[Document]:
return [Document(page_content=query)] return [Document(page_content=query)]
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
return [Document(page_content=query)] return [Document(page_content=query)]

View File

@ -1057,7 +1057,9 @@ class VectorStoreRetriever(BaseRetriever):
def _get_ls_params(self, **kwargs: Any) -> LangSmithRetrieverParams: def _get_ls_params(self, **kwargs: Any) -> LangSmithRetrieverParams:
"""Get standard params for tracing.""" """Get standard params for tracing."""
ls_params = super()._get_ls_params(**kwargs) _kwargs = self.search_kwargs | kwargs
ls_params = super()._get_ls_params(**_kwargs)
ls_params["ls_vector_store_provider"] = self.vectorstore.__class__.__name__ ls_params["ls_vector_store_provider"] = self.vectorstore.__class__.__name__
if self.vectorstore.embeddings: if self.vectorstore.embeddings:
@ -1074,43 +1076,45 @@ class VectorStoreRetriever(BaseRetriever):
return ls_params return ls_params
def _get_relevant_documents( def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
) -> list[Document]: ) -> list[Document]:
_kwargs = self.search_kwargs | kwargs
if self.search_type == "similarity": if self.search_type == "similarity":
docs = self.vectorstore.similarity_search(query, **self.search_kwargs) docs = self.vectorstore.similarity_search(query, **_kwargs)
elif self.search_type == "similarity_score_threshold": elif self.search_type == "similarity_score_threshold":
docs_and_similarities = ( docs_and_similarities = (
self.vectorstore.similarity_search_with_relevance_scores( self.vectorstore.similarity_search_with_relevance_scores(
query, **self.search_kwargs query, **_kwargs
) )
) )
docs = [doc for doc, _ in docs_and_similarities] docs = [doc for doc, _ in docs_and_similarities]
elif self.search_type == "mmr": elif self.search_type == "mmr":
docs = self.vectorstore.max_marginal_relevance_search( docs = self.vectorstore.max_marginal_relevance_search(query, **_kwargs)
query, **self.search_kwargs
)
else: else:
msg = f"search_type of {self.search_type} not allowed." msg = f"search_type of {self.search_type} not allowed."
raise ValueError(msg) raise ValueError(msg)
return docs return docs
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> list[Document]: ) -> list[Document]:
_kwargs = self.search_kwargs | kwargs
if self.search_type == "similarity": if self.search_type == "similarity":
docs = await self.vectorstore.asimilarity_search( docs = await self.vectorstore.asimilarity_search(query, **_kwargs)
query, **self.search_kwargs
)
elif self.search_type == "similarity_score_threshold": elif self.search_type == "similarity_score_threshold":
docs_and_similarities = ( docs_and_similarities = (
await self.vectorstore.asimilarity_search_with_relevance_scores( await self.vectorstore.asimilarity_search_with_relevance_scores(
query, **self.search_kwargs query, **_kwargs
) )
) )
docs = [doc for doc, _ in docs_and_similarities] docs = [doc for doc, _ in docs_and_similarities]
elif self.search_type == "mmr": elif self.search_type == "mmr":
docs = await self.vectorstore.amax_marginal_relevance_search( docs = await self.vectorstore.amax_marginal_relevance_search(
query, **self.search_kwargs query, **_kwargs
) )
else: else:
msg = f"search_type of {self.search_type} not allowed." msg = f"search_type of {self.search_type} not allowed."