mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-28 09:28:48 +00:00
core: add kwargs support to VectorStore (#25934)
has been missing the passthrough until now --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
87c50f99e5
commit
de7996c2ca
@ -855,7 +855,7 @@ class GraphVectorStoreRetriever(VectorStoreRetriever):
|
||||
return cast(GraphVectorStore, self.vectorstore)
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
|
||||
) -> list[Document]:
|
||||
if self.search_type == "traversal":
|
||||
return list(
|
||||
@ -869,7 +869,11 @@ class GraphVectorStoreRetriever(VectorStoreRetriever):
|
||||
return super()._get_relevant_documents(query, run_manager=run_manager)
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> list[Document]:
|
||||
if self.search_type == "traversal":
|
||||
return [
|
||||
|
@ -1459,59 +1459,59 @@ class RedisVectorStoreRetriever(VectorStoreRetriever): # type: ignore[override]
|
||||
)
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
_kwargs = self.search_kwargs | kwargs
|
||||
if self.search_type == "similarity":
|
||||
docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
|
||||
docs = self.vectorstore.similarity_search(query, **_kwargs)
|
||||
elif self.search_type == "similarity_distance_threshold":
|
||||
if self.search_kwargs["distance_threshold"] is None:
|
||||
if _kwargs["distance_threshold"] is None:
|
||||
raise ValueError(
|
||||
"distance_threshold must be provided for "
|
||||
+ "similarity_distance_threshold retriever"
|
||||
)
|
||||
docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
|
||||
docs = self.vectorstore.similarity_search(query, **_kwargs)
|
||||
|
||||
elif self.search_type == "similarity_score_threshold":
|
||||
docs_and_similarities = (
|
||||
self.vectorstore.similarity_search_with_relevance_scores(
|
||||
query, **self.search_kwargs
|
||||
query, **_kwargs
|
||||
)
|
||||
)
|
||||
docs = [doc for doc, _ in docs_and_similarities]
|
||||
elif self.search_type == "mmr":
|
||||
docs = self.vectorstore.max_marginal_relevance_search(
|
||||
query, **self.search_kwargs
|
||||
)
|
||||
docs = self.vectorstore.max_marginal_relevance_search(query, **_kwargs)
|
||||
else:
|
||||
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
||||
return docs
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
_kwargs = self.search_kwargs | kwargs
|
||||
if self.search_type == "similarity":
|
||||
docs = await self.vectorstore.asimilarity_search(
|
||||
query, **self.search_kwargs
|
||||
)
|
||||
docs = await self.vectorstore.asimilarity_search(query, **_kwargs)
|
||||
elif self.search_type == "similarity_distance_threshold":
|
||||
if self.search_kwargs["distance_threshold"] is None:
|
||||
if _kwargs["distance_threshold"] is None:
|
||||
raise ValueError(
|
||||
"distance_threshold must be provided for "
|
||||
+ "similarity_distance_threshold retriever"
|
||||
)
|
||||
docs = await self.vectorstore.asimilarity_search(
|
||||
query, **self.search_kwargs
|
||||
)
|
||||
docs = await self.vectorstore.asimilarity_search(query, **_kwargs)
|
||||
elif self.search_type == "similarity_score_threshold":
|
||||
docs_and_similarities = (
|
||||
await self.vectorstore.asimilarity_search_with_relevance_scores(
|
||||
query, **self.search_kwargs
|
||||
query, **_kwargs
|
||||
)
|
||||
)
|
||||
docs = [doc for doc, _ in docs_and_similarities]
|
||||
elif self.search_type == "mmr":
|
||||
docs = await self.vectorstore.amax_marginal_relevance_search(
|
||||
query, **self.search_kwargs
|
||||
query, **_kwargs
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
||||
|
@ -745,9 +745,9 @@ class VectaraRetriever(VectorStoreRetriever): # type: ignore[override]
|
||||
)
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
docs_and_scores = self.vectorstore.vectara_query(query, self.config)
|
||||
docs_and_scores = self.vectorstore.vectara_query(query, self.config, **kwargs)
|
||||
return [doc for doc, _ in docs_and_scores]
|
||||
|
||||
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
|
||||
|
@ -2,7 +2,7 @@
|
||||
Unit tests for the PebbloRetrievalQA chain
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from typing import Any, List
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
@ -35,12 +35,16 @@ class FakeRetriever(VectorStoreRetriever):
|
||||
vectorstore: VectorStore = Mock()
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
return [Document(page_content=query)]
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
return [Document(page_content=query)]
|
||||
|
||||
|
@ -1057,7 +1057,9 @@ class VectorStoreRetriever(BaseRetriever):
|
||||
def _get_ls_params(self, **kwargs: Any) -> LangSmithRetrieverParams:
|
||||
"""Get standard params for tracing."""
|
||||
|
||||
ls_params = super()._get_ls_params(**kwargs)
|
||||
_kwargs = self.search_kwargs | kwargs
|
||||
|
||||
ls_params = super()._get_ls_params(**_kwargs)
|
||||
ls_params["ls_vector_store_provider"] = self.vectorstore.__class__.__name__
|
||||
|
||||
if self.vectorstore.embeddings:
|
||||
@ -1074,43 +1076,45 @@ class VectorStoreRetriever(BaseRetriever):
|
||||
return ls_params
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
|
||||
) -> list[Document]:
|
||||
_kwargs = self.search_kwargs | kwargs
|
||||
if self.search_type == "similarity":
|
||||
docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
|
||||
docs = self.vectorstore.similarity_search(query, **_kwargs)
|
||||
elif self.search_type == "similarity_score_threshold":
|
||||
docs_and_similarities = (
|
||||
self.vectorstore.similarity_search_with_relevance_scores(
|
||||
query, **self.search_kwargs
|
||||
query, **_kwargs
|
||||
)
|
||||
)
|
||||
docs = [doc for doc, _ in docs_and_similarities]
|
||||
elif self.search_type == "mmr":
|
||||
docs = self.vectorstore.max_marginal_relevance_search(
|
||||
query, **self.search_kwargs
|
||||
)
|
||||
docs = self.vectorstore.max_marginal_relevance_search(query, **_kwargs)
|
||||
else:
|
||||
msg = f"search_type of {self.search_type} not allowed."
|
||||
raise ValueError(msg)
|
||||
return docs
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> list[Document]:
|
||||
_kwargs = self.search_kwargs | kwargs
|
||||
if self.search_type == "similarity":
|
||||
docs = await self.vectorstore.asimilarity_search(
|
||||
query, **self.search_kwargs
|
||||
)
|
||||
docs = await self.vectorstore.asimilarity_search(query, **_kwargs)
|
||||
elif self.search_type == "similarity_score_threshold":
|
||||
docs_and_similarities = (
|
||||
await self.vectorstore.asimilarity_search_with_relevance_scores(
|
||||
query, **self.search_kwargs
|
||||
query, **_kwargs
|
||||
)
|
||||
)
|
||||
docs = [doc for doc, _ in docs_and_similarities]
|
||||
elif self.search_type == "mmr":
|
||||
docs = await self.vectorstore.amax_marginal_relevance_search(
|
||||
query, **self.search_kwargs
|
||||
query, **_kwargs
|
||||
)
|
||||
else:
|
||||
msg = f"search_type of {self.search_type} not allowed."
|
||||
|
Loading…
Reference in New Issue
Block a user