langchain[patch], docs[patch]: use byte store in multivectorretriever (#14474)

This commit is contained in:
Erick Friis
2023-12-08 16:26:11 -08:00
committed by GitHub
parent 1ef13661b9
commit c24f277b7c
2 changed files with 26 additions and 26 deletions

View File

@@ -1,8 +1,8 @@
from enum import Enum
from typing import Any, List, Optional
from typing import Dict, List, Optional
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import Field, validator
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.retrievers import BaseRetriever
from langchain_core.stores import BaseStore, ByteStore
from langchain_core.vectorstores import VectorStore
@@ -26,7 +26,7 @@ class MultiVectorRetriever(BaseRetriever):
vectorstore: VectorStore
"""The underlying vectorstore to use to store small chunks
and their embedding vectors"""
byte_store: Optional[ByteStore]
byte_store: Optional[ByteStore] = None
"""The lower-level backing storage layer for the parent documents"""
docstore: BaseStore[str, Document]
"""The storage interface for the parent documents"""
@@ -36,16 +36,16 @@ class MultiVectorRetriever(BaseRetriever):
search_type: SearchType = SearchType.similarity
"""Type of search to perform (similarity / mmr)"""
@validator("docstore", pre=True, always=True)
def shim_docstore(
cls, docstore: Optional[BaseStore[str, Document]], values: Any
) -> BaseStore[str, Document]:
@root_validator(pre=True)
def shim_docstore(cls, values: Dict) -> Dict:
byte_store = values.get("byte_store")
docstore = values.get("docstore")
if byte_store is not None:
docstore = create_kv_docstore(byte_store)
elif docstore is None:
raise Exception("You must pass a `byte_store` parameter.")
return docstore
values["docstore"] = docstore
return values
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun