Fix multi vector retriever subclassing (#14350)

Fixes #14342

@eyurtsev @baskaryan

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Jacob Lee
2023-12-06 11:12:50 -08:00
committed by GitHub
parent 7bdfc43766
commit 867ca6d0be
6 changed files with 103 additions and 40 deletions

View File

@@ -1,7 +1,8 @@
from enum import Enum
from typing import List, Optional
from typing import Any, List, Optional
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import Field, validator
from langchain_core.retrievers import BaseRetriever
from langchain_core.stores import BaseStore, ByteStore
from langchain_core.vectorstores import VectorStore
@@ -25,36 +26,26 @@ class MultiVectorRetriever(BaseRetriever):
vectorstore: VectorStore
"""The underlying vectorstore to use to store small chunks
and their embedding vectors"""
byte_store: Optional[ByteStore]
"""The lower-level backing storage layer for the parent documents"""
docstore: BaseStore[str, Document]
"""The storage layer for the parent documents"""
id_key: str
search_kwargs: dict
"""The storage interface for the parent documents"""
id_key: str = "doc_id"
search_kwargs: dict = Field(default_factory=dict)
"""Keyword arguments to pass to the search function."""
search_type: SearchType
search_type: SearchType = SearchType.similarity
"""Type of search to perform (similarity / mmr)"""
def __init__(
self,
*,
vectorstore: VectorStore,
docstore: Optional[BaseStore[str, Document]] = None,
base_store: Optional[ByteStore] = None,
id_key: str = "doc_id",
search_kwargs: Optional[dict] = None,
search_type: SearchType = SearchType.similarity,
):
if base_store is not None:
docstore = create_kv_docstore(base_store)
@validator("docstore", pre=True, always=True)
def shim_docstore(
cls, docstore: Optional[BaseStore[str, Document]], values: Any
) -> BaseStore[str, Document]:
byte_store = values.get("byte_store")
if byte_store is not None:
docstore = create_kv_docstore(byte_store)
elif docstore is None:
raise Exception("You must pass a `base_store` parameter.")
super().__init__(
vectorstore=vectorstore,
docstore=docstore,
id_key=id_key,
search_kwargs=search_kwargs if search_kwargs is not None else {},
search_type=search_type,
)
raise Exception("You must pass a `byte_store` parameter.")
return docstore
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun

View File

@@ -80,7 +80,7 @@ class InMemoryVectorStore(VectorStore):
*,
ids: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> None:
) -> List[str]:
"""Add the given documents to the store (insert behavior)."""
if ids and len(ids) != len(documents):
raise ValueError(
@@ -97,6 +97,8 @@ class InMemoryVectorStore(VectorStore):
)
self.store[_id] = document
return list(ids)
async def aadd_documents(
self,
documents: Sequence[Document],

View File

@@ -0,0 +1,30 @@
from typing import Any, List
from langchain_core.documents import Document
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import InMemoryStore
from tests.unit_tests.indexes.test_indexing import InMemoryVectorStore
class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
res = self.store.get(query)
if res is None:
return []
return [res]
def test_multi_vector_retriever_initialization() -> None:
vectorstore = InMemoryVectorstoreWithSearch()
retriever = MultiVectorRetriever(
vectorstore=vectorstore, docstore=InMemoryStore(), doc_id="doc_id"
)
documents = [Document(page_content="test document", metadata={"doc_id": "1"})]
retriever.vectorstore.add_documents(documents, ids=["1"])
retriever.docstore.mset(list(zip(["1"], documents)))
results = retriever.invoke("1")
assert len(results) > 0
assert results[0].page_content == "test document"

View File

@@ -0,0 +1,40 @@
from typing import Any, List, Sequence
from langchain_core.documents import Document
from langchain.retrievers import ParentDocumentRetriever
from langchain.storage import InMemoryStore
from langchain.text_splitter import CharacterTextSplitter
from tests.unit_tests.indexes.test_indexing import InMemoryVectorStore
class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
res = self.store.get(query)
if res is None:
return []
return [res]
def add_documents(self, documents: Sequence[Document], **kwargs: Any) -> List[str]:
print(documents)
return super().add_documents(
documents, ids=[f"{i}" for i in range(len(documents))]
)
def test_parent_document_retriever_initialization() -> None:
vectorstore = InMemoryVectorstoreWithSearch()
store = InMemoryStore()
child_splitter = CharacterTextSplitter(chunk_size=400)
documents = [Document(page_content="test document")]
retriever = ParentDocumentRetriever(
vectorstore=vectorstore,
docstore=store,
child_splitter=child_splitter,
)
retriever.add_documents(documents)
results = retriever.invoke("0")
assert len(results) > 0
assert results[0].page_content == "test document"