core: add kwargs support to VectorStore (#25934)

has been missing the passthrough until now

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Harrison Chase 2024-12-16 10:57:57 -08:00 committed by GitHub
parent 87c50f99e5
commit de7996c2ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 50 additions and 38 deletions

View File

@ -855,7 +855,7 @@ class GraphVectorStoreRetriever(VectorStoreRetriever):
return cast(GraphVectorStore, self.vectorstore)
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
) -> list[Document]:
if self.search_type == "traversal":
return list(
@ -869,7 +869,11 @@ class GraphVectorStoreRetriever(VectorStoreRetriever):
return super()._get_relevant_documents(query, run_manager=run_manager)
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> list[Document]:
if self.search_type == "traversal":
return [

View File

@ -1459,59 +1459,59 @@ class RedisVectorStoreRetriever(VectorStoreRetriever): # type: ignore[override]
)
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
) -> List[Document]:
_kwargs = self.search_kwargs | kwargs
if self.search_type == "similarity":
docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
docs = self.vectorstore.similarity_search(query, **_kwargs)
elif self.search_type == "similarity_distance_threshold":
if self.search_kwargs["distance_threshold"] is None:
if _kwargs["distance_threshold"] is None:
raise ValueError(
"distance_threshold must be provided for "
+ "similarity_distance_threshold retriever"
)
docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
docs = self.vectorstore.similarity_search(query, **_kwargs)
elif self.search_type == "similarity_score_threshold":
docs_and_similarities = (
self.vectorstore.similarity_search_with_relevance_scores(
query, **self.search_kwargs
query, **_kwargs
)
)
docs = [doc for doc, _ in docs_and_similarities]
elif self.search_type == "mmr":
docs = self.vectorstore.max_marginal_relevance_search(
query, **self.search_kwargs
)
docs = self.vectorstore.max_marginal_relevance_search(query, **_kwargs)
else:
raise ValueError(f"search_type of {self.search_type} not allowed.")
return docs
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
_kwargs = self.search_kwargs | kwargs
if self.search_type == "similarity":
docs = await self.vectorstore.asimilarity_search(
query, **self.search_kwargs
)
docs = await self.vectorstore.asimilarity_search(query, **_kwargs)
elif self.search_type == "similarity_distance_threshold":
if self.search_kwargs["distance_threshold"] is None:
if _kwargs["distance_threshold"] is None:
raise ValueError(
"distance_threshold must be provided for "
+ "similarity_distance_threshold retriever"
)
docs = await self.vectorstore.asimilarity_search(
query, **self.search_kwargs
)
docs = await self.vectorstore.asimilarity_search(query, **_kwargs)
elif self.search_type == "similarity_score_threshold":
docs_and_similarities = (
await self.vectorstore.asimilarity_search_with_relevance_scores(
query, **self.search_kwargs
query, **_kwargs
)
)
docs = [doc for doc, _ in docs_and_similarities]
elif self.search_type == "mmr":
docs = await self.vectorstore.amax_marginal_relevance_search(
query, **self.search_kwargs
query, **_kwargs
)
else:
raise ValueError(f"search_type of {self.search_type} not allowed.")

View File

@ -745,9 +745,9 @@ class VectaraRetriever(VectorStoreRetriever): # type: ignore[override]
)
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
) -> List[Document]:
docs_and_scores = self.vectorstore.vectara_query(query, self.config)
docs_and_scores = self.vectorstore.vectara_query(query, self.config, **kwargs)
return [doc for doc, _ in docs_and_scores]
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:

View File

@ -2,7 +2,7 @@
Unit tests for the PebbloRetrievalQA chain
"""
from typing import List
from typing import Any, List
from unittest.mock import Mock
import pytest
@ -35,12 +35,16 @@ class FakeRetriever(VectorStoreRetriever):
vectorstore: VectorStore = Mock()
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
) -> List[Document]:
return [Document(page_content=query)]
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
return [Document(page_content=query)]

View File

@ -1057,7 +1057,9 @@ class VectorStoreRetriever(BaseRetriever):
def _get_ls_params(self, **kwargs: Any) -> LangSmithRetrieverParams:
"""Get standard params for tracing."""
ls_params = super()._get_ls_params(**kwargs)
_kwargs = self.search_kwargs | kwargs
ls_params = super()._get_ls_params(**_kwargs)
ls_params["ls_vector_store_provider"] = self.vectorstore.__class__.__name__
if self.vectorstore.embeddings:
@ -1074,43 +1076,45 @@ class VectorStoreRetriever(BaseRetriever):
return ls_params
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
) -> list[Document]:
_kwargs = self.search_kwargs | kwargs
if self.search_type == "similarity":
docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
docs = self.vectorstore.similarity_search(query, **_kwargs)
elif self.search_type == "similarity_score_threshold":
docs_and_similarities = (
self.vectorstore.similarity_search_with_relevance_scores(
query, **self.search_kwargs
query, **_kwargs
)
)
docs = [doc for doc, _ in docs_and_similarities]
elif self.search_type == "mmr":
docs = self.vectorstore.max_marginal_relevance_search(
query, **self.search_kwargs
)
docs = self.vectorstore.max_marginal_relevance_search(query, **_kwargs)
else:
msg = f"search_type of {self.search_type} not allowed."
raise ValueError(msg)
return docs
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> list[Document]:
_kwargs = self.search_kwargs | kwargs
if self.search_type == "similarity":
docs = await self.vectorstore.asimilarity_search(
query, **self.search_kwargs
)
docs = await self.vectorstore.asimilarity_search(query, **_kwargs)
elif self.search_type == "similarity_score_threshold":
docs_and_similarities = (
await self.vectorstore.asimilarity_search_with_relevance_scores(
query, **self.search_kwargs
query, **_kwargs
)
)
docs = [doc for doc, _ in docs_and_similarities]
elif self.search_type == "mmr":
docs = await self.vectorstore.amax_marginal_relevance_search(
query, **self.search_kwargs
query, **_kwargs
)
else:
msg = f"search_type of {self.search_type} not allowed."