Allow base_store to be used directly with MultiVectorRetriever (#14202)

Allow users to pass a generic `BaseStore[str, bytes]` to
MultiVectorRetriever, removing the need to use the `create_kv_docstore`
method. This encoding will now happen internally.

@rlancemartin @eyurtsev

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
Jacob Lee
2023-12-04 14:43:32 -08:00
committed by GitHub
parent 67662564f3
commit a26c4a0930
2 changed files with 63 additions and 48 deletions

View File

@@ -1,13 +1,13 @@
from enum import Enum
from typing import List
from typing import List, Optional
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import Field
from langchain_core.retrievers import BaseRetriever
from langchain_core.stores import BaseStore
from langchain_core.vectorstores import VectorStore
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.storage._lc_store import create_kv_docstore
class SearchType(str, Enum):
@@ -27,12 +27,35 @@ class MultiVectorRetriever(BaseRetriever):
and their embedding vectors"""
docstore: BaseStore[str, Document]
"""The storage layer for the parent documents"""
id_key: str = "doc_id"
search_kwargs: dict = Field(default_factory=dict)
id_key: str
search_kwargs: dict
"""Keyword arguments to pass to the search function."""
search_type: SearchType = SearchType.similarity
search_type: SearchType
"""Type of search to perform (similarity / mmr)"""
def __init__(
self,
*,
vectorstore: VectorStore,
docstore: Optional[BaseStore[str, Document]] = None,
base_store: Optional[BaseStore[str, bytes]] = None,
id_key: str = "doc_id",
search_kwargs: Optional[dict] = None,
search_type: SearchType = SearchType.similarity,
):
if base_store is not None:
docstore = create_kv_docstore(base_store)
elif docstore is None:
raise Exception("You must pass a `base_store` parameter.")
super().__init__(
vectorstore=vectorstore,
docstore=docstore,
id_key=id_key,
search_kwargs=search_kwargs if search_kwargs is not None else {},
search_type=search_type,
)
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]: