mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-30 10:23:30 +00:00
core: add kwargs support to VectorStore (#25934)
has been missing the passthrough until now --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
87c50f99e5
commit
de7996c2ca
@ -855,7 +855,7 @@ class GraphVectorStoreRetriever(VectorStoreRetriever):
|
|||||||
return cast(GraphVectorStore, self.vectorstore)
|
return cast(GraphVectorStore, self.vectorstore)
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
|
||||||
) -> list[Document]:
|
) -> list[Document]:
|
||||||
if self.search_type == "traversal":
|
if self.search_type == "traversal":
|
||||||
return list(
|
return list(
|
||||||
@ -869,7 +869,11 @@ class GraphVectorStoreRetriever(VectorStoreRetriever):
|
|||||||
return super()._get_relevant_documents(query, run_manager=run_manager)
|
return super()._get_relevant_documents(query, run_manager=run_manager)
|
||||||
|
|
||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
self,
|
||||||
|
query: str,
|
||||||
|
*,
|
||||||
|
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||||
|
**kwargs: Any,
|
||||||
) -> list[Document]:
|
) -> list[Document]:
|
||||||
if self.search_type == "traversal":
|
if self.search_type == "traversal":
|
||||||
return [
|
return [
|
||||||
|
@ -1459,59 +1459,59 @@ class RedisVectorStoreRetriever(VectorStoreRetriever): # type: ignore[override]
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
|
_kwargs = self.search_kwargs | kwargs
|
||||||
if self.search_type == "similarity":
|
if self.search_type == "similarity":
|
||||||
docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
|
docs = self.vectorstore.similarity_search(query, **_kwargs)
|
||||||
elif self.search_type == "similarity_distance_threshold":
|
elif self.search_type == "similarity_distance_threshold":
|
||||||
if self.search_kwargs["distance_threshold"] is None:
|
if _kwargs["distance_threshold"] is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"distance_threshold must be provided for "
|
"distance_threshold must be provided for "
|
||||||
+ "similarity_distance_threshold retriever"
|
+ "similarity_distance_threshold retriever"
|
||||||
)
|
)
|
||||||
docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
|
docs = self.vectorstore.similarity_search(query, **_kwargs)
|
||||||
|
|
||||||
elif self.search_type == "similarity_score_threshold":
|
elif self.search_type == "similarity_score_threshold":
|
||||||
docs_and_similarities = (
|
docs_and_similarities = (
|
||||||
self.vectorstore.similarity_search_with_relevance_scores(
|
self.vectorstore.similarity_search_with_relevance_scores(
|
||||||
query, **self.search_kwargs
|
query, **_kwargs
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
docs = [doc for doc, _ in docs_and_similarities]
|
docs = [doc for doc, _ in docs_and_similarities]
|
||||||
elif self.search_type == "mmr":
|
elif self.search_type == "mmr":
|
||||||
docs = self.vectorstore.max_marginal_relevance_search(
|
docs = self.vectorstore.max_marginal_relevance_search(query, **_kwargs)
|
||||||
query, **self.search_kwargs
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
self,
|
||||||
|
query: str,
|
||||||
|
*,
|
||||||
|
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||||
|
**kwargs: Any,
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
|
_kwargs = self.search_kwargs | kwargs
|
||||||
if self.search_type == "similarity":
|
if self.search_type == "similarity":
|
||||||
docs = await self.vectorstore.asimilarity_search(
|
docs = await self.vectorstore.asimilarity_search(query, **_kwargs)
|
||||||
query, **self.search_kwargs
|
|
||||||
)
|
|
||||||
elif self.search_type == "similarity_distance_threshold":
|
elif self.search_type == "similarity_distance_threshold":
|
||||||
if self.search_kwargs["distance_threshold"] is None:
|
if _kwargs["distance_threshold"] is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"distance_threshold must be provided for "
|
"distance_threshold must be provided for "
|
||||||
+ "similarity_distance_threshold retriever"
|
+ "similarity_distance_threshold retriever"
|
||||||
)
|
)
|
||||||
docs = await self.vectorstore.asimilarity_search(
|
docs = await self.vectorstore.asimilarity_search(query, **_kwargs)
|
||||||
query, **self.search_kwargs
|
|
||||||
)
|
|
||||||
elif self.search_type == "similarity_score_threshold":
|
elif self.search_type == "similarity_score_threshold":
|
||||||
docs_and_similarities = (
|
docs_and_similarities = (
|
||||||
await self.vectorstore.asimilarity_search_with_relevance_scores(
|
await self.vectorstore.asimilarity_search_with_relevance_scores(
|
||||||
query, **self.search_kwargs
|
query, **_kwargs
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
docs = [doc for doc, _ in docs_and_similarities]
|
docs = [doc for doc, _ in docs_and_similarities]
|
||||||
elif self.search_type == "mmr":
|
elif self.search_type == "mmr":
|
||||||
docs = await self.vectorstore.amax_marginal_relevance_search(
|
docs = await self.vectorstore.amax_marginal_relevance_search(
|
||||||
query, **self.search_kwargs
|
query, **_kwargs
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
||||||
|
@ -745,9 +745,9 @@ class VectaraRetriever(VectorStoreRetriever): # type: ignore[override]
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
docs_and_scores = self.vectorstore.vectara_query(query, self.config)
|
docs_and_scores = self.vectorstore.vectara_query(query, self.config, **kwargs)
|
||||||
return [doc for doc, _ in docs_and_scores]
|
return [doc for doc, _ in docs_and_scores]
|
||||||
|
|
||||||
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
|
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
Unit tests for the PebbloRetrievalQA chain
|
Unit tests for the PebbloRetrievalQA chain
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List
|
from typing import Any, List
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -35,12 +35,16 @@ class FakeRetriever(VectorStoreRetriever):
|
|||||||
vectorstore: VectorStore = Mock()
|
vectorstore: VectorStore = Mock()
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
return [Document(page_content=query)]
|
return [Document(page_content=query)]
|
||||||
|
|
||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
self,
|
||||||
|
query: str,
|
||||||
|
*,
|
||||||
|
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||||
|
**kwargs: Any,
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
return [Document(page_content=query)]
|
return [Document(page_content=query)]
|
||||||
|
|
||||||
|
@ -1057,7 +1057,9 @@ class VectorStoreRetriever(BaseRetriever):
|
|||||||
def _get_ls_params(self, **kwargs: Any) -> LangSmithRetrieverParams:
|
def _get_ls_params(self, **kwargs: Any) -> LangSmithRetrieverParams:
|
||||||
"""Get standard params for tracing."""
|
"""Get standard params for tracing."""
|
||||||
|
|
||||||
ls_params = super()._get_ls_params(**kwargs)
|
_kwargs = self.search_kwargs | kwargs
|
||||||
|
|
||||||
|
ls_params = super()._get_ls_params(**_kwargs)
|
||||||
ls_params["ls_vector_store_provider"] = self.vectorstore.__class__.__name__
|
ls_params["ls_vector_store_provider"] = self.vectorstore.__class__.__name__
|
||||||
|
|
||||||
if self.vectorstore.embeddings:
|
if self.vectorstore.embeddings:
|
||||||
@ -1074,43 +1076,45 @@ class VectorStoreRetriever(BaseRetriever):
|
|||||||
return ls_params
|
return ls_params
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
|
||||||
) -> list[Document]:
|
) -> list[Document]:
|
||||||
|
_kwargs = self.search_kwargs | kwargs
|
||||||
if self.search_type == "similarity":
|
if self.search_type == "similarity":
|
||||||
docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
|
docs = self.vectorstore.similarity_search(query, **_kwargs)
|
||||||
elif self.search_type == "similarity_score_threshold":
|
elif self.search_type == "similarity_score_threshold":
|
||||||
docs_and_similarities = (
|
docs_and_similarities = (
|
||||||
self.vectorstore.similarity_search_with_relevance_scores(
|
self.vectorstore.similarity_search_with_relevance_scores(
|
||||||
query, **self.search_kwargs
|
query, **_kwargs
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
docs = [doc for doc, _ in docs_and_similarities]
|
docs = [doc for doc, _ in docs_and_similarities]
|
||||||
elif self.search_type == "mmr":
|
elif self.search_type == "mmr":
|
||||||
docs = self.vectorstore.max_marginal_relevance_search(
|
docs = self.vectorstore.max_marginal_relevance_search(query, **_kwargs)
|
||||||
query, **self.search_kwargs
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
msg = f"search_type of {self.search_type} not allowed."
|
msg = f"search_type of {self.search_type} not allowed."
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
self,
|
||||||
|
query: str,
|
||||||
|
*,
|
||||||
|
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||||
|
**kwargs: Any,
|
||||||
) -> list[Document]:
|
) -> list[Document]:
|
||||||
|
_kwargs = self.search_kwargs | kwargs
|
||||||
if self.search_type == "similarity":
|
if self.search_type == "similarity":
|
||||||
docs = await self.vectorstore.asimilarity_search(
|
docs = await self.vectorstore.asimilarity_search(query, **_kwargs)
|
||||||
query, **self.search_kwargs
|
|
||||||
)
|
|
||||||
elif self.search_type == "similarity_score_threshold":
|
elif self.search_type == "similarity_score_threshold":
|
||||||
docs_and_similarities = (
|
docs_and_similarities = (
|
||||||
await self.vectorstore.asimilarity_search_with_relevance_scores(
|
await self.vectorstore.asimilarity_search_with_relevance_scores(
|
||||||
query, **self.search_kwargs
|
query, **_kwargs
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
docs = [doc for doc, _ in docs_and_similarities]
|
docs = [doc for doc, _ in docs_and_similarities]
|
||||||
elif self.search_type == "mmr":
|
elif self.search_type == "mmr":
|
||||||
docs = await self.vectorstore.amax_marginal_relevance_search(
|
docs = await self.vectorstore.amax_marginal_relevance_search(
|
||||||
query, **self.search_kwargs
|
query, **_kwargs
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
msg = f"search_type of {self.search_type} not allowed."
|
msg = f"search_type of {self.search_type} not allowed."
|
||||||
|
Loading…
Reference in New Issue
Block a user