Harrison/multi vector (#9700)

Harrison Chase authored 2023-08-24 06:42:42 -07:00, committed by GitHub
parent b048236c1a
commit 9963b32e59
5 changed files with 411 additions and 35 deletions

View File

@@ -40,6 +40,7 @@ from langchain.retrievers.merger_retriever import MergerRetriever
from langchain.retrievers.metal import MetalRetriever
from langchain.retrievers.milvus import MilvusRetriever
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.retrievers.parent_document_retriever import ParentDocumentRetriever
from langchain.retrievers.pinecone_hybrid_search import PineconeHybridSearchRetriever
from langchain.retrievers.pubmed import PubMedRetriever
@@ -92,4 +93,5 @@ __all__ = [
"WebResearchRetriever",
"EnsembleRetriever",
"ParentDocumentRetriever",
"MultiVectorRetriever",
]
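With the export added to __all__, the new class is importable from the package root alongside the existing retrievers. A minimal sketch (not part of the diff), assuming a langchain install that includes this commit:

# Import sketch, assuming a langchain build that contains this commit.
from langchain.retrievers import MultiVectorRetriever, ParentDocumentRetriever

# ParentDocumentRetriever is refactored later in this commit to subclass
# MultiVectorRetriever, so both share the same lookup logic.
assert issubclass(ParentDocumentRetriever, MultiVectorRetriever)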

View File

@@ -0,0 +1,39 @@
from typing import List

from pydantic import Field

from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.schema import BaseRetriever, BaseStore, Document
from langchain.vectorstores import VectorStore


class MultiVectorRetriever(BaseRetriever):
    """Retrieve from a set of multiple embeddings for the same document."""

    vectorstore: VectorStore
    """The underlying vectorstore to use to store small chunks
    and their embedding vectors"""
    docstore: BaseStore[str, Document]
    """The storage layer for the parent documents"""
    id_key: str = "doc_id"
    search_kwargs: dict = Field(default_factory=dict)
    """Keyword arguments to pass to the search function."""

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get documents relevant to a query.

        Args:
            query: String to find relevant documents for
            run_manager: The callbacks handler to use

        Returns:
            List of relevant documents
        """
        sub_docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
        # We do this to maintain the order of the ids that are returned
        ids = []
        for d in sub_docs:
            if d.metadata[self.id_key] not in ids:
                ids.append(d.metadata[self.id_key])
        docs = self.docstore.mget(ids)
        return [d for d in docs if d is not None]
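The class above only reads from the two stores, so the caller populates the vectorstore with small documents whose metadata carries the parent's id, and the docstore with the parents under those same ids. A usage sketch, not part of the commit, assuming Chroma, OpenAIEmbeddings, and InMemoryStore are available; the collection name, texts, and doc_id below are purely illustrative:

# Usage sketch. Requires chromadb and an OpenAI API key for the embeddings.
import uuid

from langchain.embeddings import OpenAIEmbeddings
from langchain.retrievers import MultiVectorRetriever
from langchain.schema import Document
from langchain.storage import InMemoryStore
from langchain.vectorstores import Chroma

vectorstore = Chroma(collection_name="summaries", embedding_function=OpenAIEmbeddings())
docstore = InMemoryStore()
retriever = MultiVectorRetriever(vectorstore=vectorstore, docstore=docstore, id_key="doc_id")

# One parent document, indexed under several small "views" (a summary and a
# hypothetical question) that all carry the same doc_id in their metadata.
doc_id = str(uuid.uuid4())
parent = Document(page_content="LangChain is a framework for building LLM applications.")
docstore.mset([(doc_id, parent)])
vectorstore.add_documents(
    [
        Document(page_content="A short summary of the LangChain framework.", metadata={"doc_id": doc_id}),
        Document(page_content="What is LangChain used for?", metadata={"doc_id": doc_id}),
    ]
)

# Similarity search hits the small documents; the retriever deduplicates their
# doc_ids and returns the parent documents from the docstore.
print(retriever.get_relevant_documents("What is LangChain?"))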

View File

@@ -1,16 +1,12 @@
import uuid
from typing import List, Optional
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.pydantic_v1 import Field
from langchain.retrievers import MultiVectorRetriever
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
from langchain.schema.storage import BaseStore
from langchain.text_splitter import TextSplitter
from langchain.vectorstores.base import VectorStore
class ParentDocumentRetriever(BaseRetriever):
class ParentDocumentRetriever(MultiVectorRetriever):
"""Retrieve small chunks then retrieve their parent documents.
When splitting documents for retrieval, there are often conflicting desires:
@@ -59,40 +55,14 @@ class ParentDocumentRetriever(BaseRetriever):
            )
    """

    vectorstore: VectorStore
    """The underlying vectorstore to use to store small chunks
    and their embedding vectors"""
    docstore: BaseStore[str, Document]
    """The storage layer for the parent documents"""
    child_splitter: TextSplitter
    """The text splitter to use to create child documents."""
    id_key: str = "doc_id"
    """The key to use to track the parent id. This will be stored in the
    metadata of child documents."""
    parent_splitter: Optional[TextSplitter] = None
    """The text splitter to use to create parent documents.
    If none, then the parent documents will be the raw documents passed in."""
    search_kwargs: dict = Field(default_factory=dict)
    """Keyword arguments to pass to the search function."""

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get documents relevant to a query.

        Args:
            query: String to find relevant documents for
            run_manager: The callbacks handler to use

        Returns:
            List of relevant documents
        """
        sub_docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
        # We do this to maintain the order of the ids that are returned
        ids = []
        for d in sub_docs:
            if d.metadata[self.id_key] not in ids:
                ids.append(d.metadata[self.id_key])
        docs = self.docstore.mget(ids)
        return [d for d in docs if d is not None]

    def add_documents(
        self,