mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-23 03:19:38 +00:00
Allow base_store to be used directly with MultiVectorRetriever (#14202)
Allow users to pass a generic `BaseStore[str, bytes]` directly to `MultiVectorRetriever` through the new `base_store` argument, removing the need to call the `create_kv_docstore` function themselves. The encoding now happens internally.

@rlancemartin @eyurtsev

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
@@ -1,13 +1,13 @@
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.stores import BaseStore
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain.storage._lc_store import create_kv_docstore
|
||||
|
||||
|
||||
class SearchType(str, Enum):
|
||||
@@ -27,12 +27,35 @@ class MultiVectorRetriever(BaseRetriever):
|
||||
and their embedding vectors"""
|
||||
docstore: BaseStore[str, Document]
|
||||
"""The storage layer for the parent documents"""
|
||||
id_key: str = "doc_id"
|
||||
search_kwargs: dict = Field(default_factory=dict)
|
||||
id_key: str
|
||||
search_kwargs: dict
|
||||
"""Keyword arguments to pass to the search function."""
|
||||
search_type: SearchType = SearchType.similarity
|
||||
search_type: SearchType
|
||||
"""Type of search to perform (similarity / mmr)"""
|
||||
|
||||
def __init__(
    self,
    *,
    vectorstore: VectorStore,
    docstore: Optional[BaseStore[str, Document]] = None,
    base_store: Optional[BaseStore[str, bytes]] = None,
    id_key: str = "doc_id",
    search_kwargs: Optional[dict] = None,
    search_type: SearchType = SearchType.similarity,
) -> None:
    """Build the retriever from a vectorstore and a parent-document store.

    Exactly one of ``docstore`` / ``base_store`` is needed; if both are
    given, ``base_store`` wins and ``docstore`` is ignored.

    Args:
        vectorstore: Forwarded unchanged to the parent initializer.
        docstore: A ready-to-use ``BaseStore[str, Document]``.
        base_store: A ``BaseStore[str, bytes]``; it is wrapped with
            ``create_kv_docstore`` and the result is used as ``docstore``.
        id_key: Forwarded unchanged to the parent initializer.
        search_kwargs: Keyword arguments to pass to the search function.
            ``None`` is replaced with a fresh empty dict.
        search_type: Type of search to perform (similarity / mmr).

    Raises:
        ValueError: If neither ``docstore`` nor ``base_store`` is provided.
    """
    if base_store is not None:
        # A byte-level store takes precedence over any docstore passed in.
        docstore = create_kv_docstore(base_store)
    elif docstore is None:
        # ValueError subclasses Exception, so callers that caught the
        # previous bare Exception continue to work.
        raise ValueError(
            "You must pass either a `base_store` or a `docstore` parameter."
        )

    super().__init__(
        vectorstore=vectorstore,
        docstore=docstore,
        id_key=id_key,
        # Build a new dict per instance rather than sharing a default.
        search_kwargs=search_kwargs if search_kwargs is not None else {},
        search_type=search_type,
    )
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
|
Reference in New Issue
Block a user