Mirror of https://github.com/hwchase17/langchain.git (synced 2025-09-23 11:30:37 +00:00)
Fix multi vector retriever subclassing (#14350)
Fixes #14342 @eyurtsev @baskaryan

Co-authored-by: Erick Friis <erick@langchain.dev>
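The heart of the change: the retriever is a pydantic model, and the hand-written __init__ removed below accepted only its own fixed keyword arguments, so subclasses that declare extra fields (ParentDocumentRetriever adds child_splitter, as the new test shows) could not be constructed. Moving the byte_store/docstore shim into a @validator keeps pydantic's generated __init__, so subclassing works again. A minimal standalone sketch of that pattern follows, assuming pydantic v1 semantics (what langchain_core.pydantic_v1 provides); the class names and the wrap_as_docstore helper are stand-ins, not LangChain code:

# Standalone sketch (not LangChain code): a validator-based shim survives
# subclassing, whereas a hand-written fixed-signature __init__ does not.
from typing import Any, Optional

from pydantic import BaseModel, validator


def wrap_as_docstore(byte_store: dict) -> dict:
    # Stand-in for create_kv_docstore: wrap a byte store as a document store.
    return {"wrapped": byte_store}


class Retriever(BaseModel):
    byte_store: Optional[dict] = None
    docstore: Optional[dict] = None

    @validator("docstore", pre=True, always=True)
    def shim_docstore(cls, docstore: Optional[dict], values: Any) -> dict:
        # Same shape as the PR's validator: prefer byte_store when given,
        # otherwise require an explicit docstore.
        byte_store = values.get("byte_store")
        if byte_store is not None:
            docstore = wrap_as_docstore(byte_store)
        elif docstore is None:
            raise Exception("You must pass a `byte_store` parameter.")
        return docstore


class ChildRetriever(Retriever):
    # Subclasses can add fields freely because the base class keeps
    # pydantic's generated __init__ (there is no hand-written one to bypass).
    child_option: int = 0


print(ChildRetriever(byte_store={"k": b"v"}, child_option=3).docstore)
# -> {'wrapped': {'k': b'v'}}

With a fixed-signature __init__ like the one being removed, the child_option keyword in this sketch would be rejected with a TypeError before pydantic ever saw it.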
langchain/retrievers/multi_vector.py

@@ -1,7 +1,8 @@
 from enum import Enum
-from typing import List, Optional
+from typing import Any, List, Optional
 
 from langchain_core.documents import Document
+from langchain_core.pydantic_v1 import Field, validator
 from langchain_core.retrievers import BaseRetriever
 from langchain_core.stores import BaseStore, ByteStore
 from langchain_core.vectorstores import VectorStore
@@ -25,36 +26,26 @@ class MultiVectorRetriever(BaseRetriever):
     vectorstore: VectorStore
     """The underlying vectorstore to use to store small chunks
     and their embedding vectors"""
+    byte_store: Optional[ByteStore]
+    """The lower-level backing storage layer for the parent documents"""
     docstore: BaseStore[str, Document]
-    """The storage layer for the parent documents"""
-    id_key: str
-    search_kwargs: dict
+    """The storage interface for the parent documents"""
+    id_key: str = "doc_id"
+    search_kwargs: dict = Field(default_factory=dict)
     """Keyword arguments to pass to the search function."""
-    search_type: SearchType
+    search_type: SearchType = SearchType.similarity
     """Type of search to perform (similarity / mmr)"""
 
-    def __init__(
-        self,
-        *,
-        vectorstore: VectorStore,
-        docstore: Optional[BaseStore[str, Document]] = None,
-        base_store: Optional[ByteStore] = None,
-        id_key: str = "doc_id",
-        search_kwargs: Optional[dict] = None,
-        search_type: SearchType = SearchType.similarity,
-    ):
-        if base_store is not None:
-            docstore = create_kv_docstore(base_store)
-        elif docstore is None:
-            raise Exception("You must pass a `base_store` parameter.")
-
-        super().__init__(
-            vectorstore=vectorstore,
-            docstore=docstore,
-            id_key=id_key,
-            search_kwargs=search_kwargs if search_kwargs is not None else {},
-            search_type=search_type,
-        )
+    @validator("docstore", pre=True, always=True)
+    def shim_docstore(
+        cls, docstore: Optional[BaseStore[str, Document]], values: Any
+    ) -> BaseStore[str, Document]:
+        byte_store = values.get("byte_store")
+        if byte_store is not None:
+            docstore = create_kv_docstore(byte_store)
+        elif docstore is None:
+            raise Exception("You must pass a `byte_store` parameter.")
+        return docstore
 
     def _get_relevant_documents(
         self, query: str, *, run_manager: CallbackManagerForRetrieverRun
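For context on the byte_store path taken by the validator above: create_kv_docstore wraps a lower-level str-to-bytes store so it can hold Document objects. A rough standalone sketch of that wrapping idea follows (an assumption about the mechanism, not the library's implementation; Doc, ByteStore, and DocStoreOverBytes are stand-ins):

# Standalone sketch: adapt a str -> bytes key-value store into a document
# store by (de)serializing documents. Stand-in types, not LangChain code.
import json
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple


@dataclass
class Doc:
    page_content: str
    metadata: dict


class ByteStore:
    """Minimal str -> bytes key-value store."""

    def __init__(self) -> None:
        self._data: Dict[str, bytes] = {}

    def mset(self, pairs: Sequence[Tuple[str, bytes]]) -> None:
        self._data.update(pairs)

    def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
        return [self._data.get(k) for k in keys]


class DocStoreOverBytes:
    """Adapter: store Doc objects by serializing them into the byte store."""

    def __init__(self, byte_store: ByteStore) -> None:
        self._bytes = byte_store

    def mset(self, pairs: Sequence[Tuple[str, Doc]]) -> None:
        encoded = [
            (key, json.dumps(
                {"page_content": d.page_content, "metadata": d.metadata}
            ).encode())
            for key, d in pairs
        ]
        self._bytes.mset(encoded)

    def mget(self, keys: Sequence[str]) -> List[Optional[Doc]]:
        out: List[Optional[Doc]] = []
        for raw in self._bytes.mget(keys):
            out.append(None if raw is None else Doc(**json.loads(raw)))
        return out


store = DocStoreOverBytes(ByteStore())
store.mset([("1", Doc("parent document", {"doc_id": "1"}))])
print(store.mget(["1", "missing"]))  # -> [Doc(...), None]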
tests/unit_tests/indexes/test_indexing.py

@@ -80,7 +80,7 @@ class InMemoryVectorStore(VectorStore):
         *,
         ids: Optional[Sequence[str]] = None,
         **kwargs: Any,
-    ) -> None:
+    ) -> List[str]:
         """Add the given documents to the store (insert behavior)."""
         if ids and len(ids) != len(documents):
             raise ValueError(
@@ -97,6 +97,8 @@ class InMemoryVectorStore(VectorStore):
                 )
             self.store[_id] = document
 
+        return list(ids)
+
     async def aadd_documents(
         self,
         documents: Sequence[Document],
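The return value matters because the test helpers below subclass this InMemoryVectorStore and chain through super().add_documents(...), which now has to hand back the ids the documents were stored under, matching the -> List[str] annotation. A tiny standalone sketch of that contract (stand-in types, not the test helper itself):

# Standalone sketch of the add_documents contract: store each document under
# an id (provided or generated) and return the ids actually used.
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Sequence
from uuid import uuid4


@dataclass
class Doc:
    page_content: str
    metadata: dict = field(default_factory=dict)


class TinyStore:
    def __init__(self) -> None:
        self.store: Dict[str, Doc] = {}

    def add_documents(
        self, documents: Sequence[Doc], ids: Optional[Sequence[str]] = None
    ) -> List[str]:
        if ids and len(ids) != len(documents):
            raise ValueError("ids must match documents in length")
        ids = list(ids) if ids else [str(uuid4()) for _ in documents]
        for _id, doc in zip(ids, documents):
            self.store[_id] = doc
        # Returning the ids lets callers (and subclasses chaining via super())
        # know which keys the documents ended up under.
        return list(ids)


print(TinyStore().add_documents([Doc("hello")]))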
New test file (unit test for MultiVectorRetriever):

@@ -0,0 +1,30 @@
+from typing import Any, List
+
+from langchain_core.documents import Document
+
+from langchain.retrievers.multi_vector import MultiVectorRetriever
+from langchain.storage import InMemoryStore
+from tests.unit_tests.indexes.test_indexing import InMemoryVectorStore
+
+
+class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
+    def similarity_search(
+        self, query: str, k: int = 4, **kwargs: Any
+    ) -> List[Document]:
+        res = self.store.get(query)
+        if res is None:
+            return []
+        return [res]
+
+
+def test_multi_vector_retriever_initialization() -> None:
+    vectorstore = InMemoryVectorstoreWithSearch()
+    retriever = MultiVectorRetriever(
+        vectorstore=vectorstore, docstore=InMemoryStore(), doc_id="doc_id"
+    )
+    documents = [Document(page_content="test document", metadata={"doc_id": "1"})]
+    retriever.vectorstore.add_documents(documents, ids=["1"])
+    retriever.docstore.mset(list(zip(["1"], documents)))
+    results = retriever.invoke("1")
+    assert len(results) > 0
+    assert results[0].page_content == "test document"
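What the test exercises: the retriever searches the vectorstore for small chunks, reads each chunk's doc_id metadata (the id_key), and looks the parent documents up in the docstore. A rough standalone sketch of that flow follows, stated as an assumption about the internals rather than the library code:

# Standalone sketch: map chunk search results back to parent documents
# via their id_key metadata. Stand-in types, not LangChain code.
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Doc:
    page_content: str
    metadata: dict = field(default_factory=dict)


def lookup_parents(
    chunks: List[Doc],          # results of a similarity search over chunks
    docstore: Dict[str, Doc],   # parent documents keyed by id
    id_key: str = "doc_id",
) -> List[Doc]:
    ids: List[str] = []
    for chunk in chunks:
        parent_id = chunk.metadata[id_key]
        if parent_id not in ids:
            ids.append(parent_id)
    return [docstore[i] for i in ids if i in docstore]


parents = {"1": Doc("test document", {"doc_id": "1"})}
chunks = [Doc("test chunk", {"doc_id": "1"})]
print(lookup_parents(chunks, parents))  # -> the parent document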
New test file (unit test for ParentDocumentRetriever):

@@ -0,0 +1,40 @@
+from typing import Any, List, Sequence
+
+from langchain_core.documents import Document
+
+from langchain.retrievers import ParentDocumentRetriever
+from langchain.storage import InMemoryStore
+from langchain.text_splitter import CharacterTextSplitter
+from tests.unit_tests.indexes.test_indexing import InMemoryVectorStore
+
+
+class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
+    def similarity_search(
+        self, query: str, k: int = 4, **kwargs: Any
+    ) -> List[Document]:
+        res = self.store.get(query)
+        if res is None:
+            return []
+        return [res]
+
+    def add_documents(self, documents: Sequence[Document], **kwargs: Any) -> List[str]:
+        print(documents)
+        return super().add_documents(
+            documents, ids=[f"{i}" for i in range(len(documents))]
+        )
+
+
+def test_parent_document_retriever_initialization() -> None:
+    vectorstore = InMemoryVectorstoreWithSearch()
+    store = InMemoryStore()
+    child_splitter = CharacterTextSplitter(chunk_size=400)
+    documents = [Document(page_content="test document")]
+    retriever = ParentDocumentRetriever(
+        vectorstore=vectorstore,
+        docstore=store,
+        child_splitter=child_splitter,
+    )
+    retriever.add_documents(documents)
+    results = retriever.invoke("0")
+    assert len(results) > 0
+    assert results[0].page_content == "test document"
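And the indexing side that retriever.add_documents(documents) drives: each parent is split into child chunks, the chunks are tagged with the parent's id and added to the vectorstore, and the parent itself goes into the docstore (the test's helper stores chunks under ids "0", "1", ..., so invoke("0") finds a chunk whose doc_id metadata points back at the parent). A rough standalone sketch of that flow, again an assumption about the internals rather than the library code:

# Standalone sketch: split parents into chunks, tag chunks with the parent id,
# index chunks and keep parents separately. Stand-in types, not LangChain code.
import uuid
from dataclasses import dataclass, field
from typing import Callable, Dict, List


@dataclass
class Doc:
    page_content: str
    metadata: dict = field(default_factory=dict)


def add_documents(
    parents: List[Doc],
    split: Callable[[Doc], List[Doc]],  # e.g. a character-based text splitter
    vector_index: List[Doc],            # stand-in for the chunk vectorstore
    docstore: Dict[str, Doc],           # stand-in for the parent docstore
    id_key: str = "doc_id",
) -> None:
    for parent in parents:
        parent_id = str(uuid.uuid4())
        for chunk in split(parent):
            chunk.metadata[id_key] = parent_id
            vector_index.append(chunk)
        docstore[parent_id] = parent


vector_index: List[Doc] = []
docstore: Dict[str, Doc] = {}
add_documents(
    [Doc("test document")],
    split=lambda d: [Doc(d.page_content)],  # trivial "splitter" for the demo
    vector_index=vector_index,
    docstore=docstore,
)
print(len(vector_index), len(docstore))  # 1 1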