From 78a1af48488ecab2a1a9bb6264d3aefb6c8a6bdc Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Thu, 1 Feb 2024 19:33:06 +0100 Subject: [PATCH] langchain[patch]: Add async methods to MultiVectorRetriever (#16878) Adds async support to multi vector retriever --- .../langchain/retrievers/multi_vector.py | 32 ++++++++++++++++++- .../retrievers/test_multi_vector.py | 13 ++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/libs/langchain/langchain/retrievers/multi_vector.py b/libs/langchain/langchain/retrievers/multi_vector.py index 167fc5d4cb8..2ac989914ae 100644 --- a/libs/langchain/langchain/retrievers/multi_vector.py +++ b/libs/langchain/langchain/retrievers/multi_vector.py @@ -1,7 +1,10 @@ from enum import Enum from typing import Dict, List, Optional -from langchain_core.callbacks import CallbackManagerForRetrieverRun +from langchain_core.callbacks import ( + AsyncCallbackManagerForRetrieverRun, + CallbackManagerForRetrieverRun, +) from langchain_core.documents import Document from langchain_core.pydantic_v1 import Field, root_validator from langchain_core.retrievers import BaseRetriever @@ -71,3 +74,30 @@ class MultiVectorRetriever(BaseRetriever): ids.append(d.metadata[self.id_key]) docs = self.docstore.mget(ids) return [d for d in docs if d is not None] + + async def _aget_relevant_documents( + self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun + ) -> List[Document]: + """Asynchronously get documents relevant to a query. + Args: + query: String to find relevant documents for + run_manager: The callbacks handler to use + Returns: + List of relevant documents + """ + if self.search_type == SearchType.mmr: + sub_docs = await self.vectorstore.amax_marginal_relevance_search( + query, **self.search_kwargs + ) + else: + sub_docs = await self.vectorstore.asimilarity_search( + query, **self.search_kwargs + ) + + # We do this to maintain the order of the ids that are returned + ids = [] + for d in sub_docs: + if self.id_key in d.metadata and d.metadata[self.id_key] not in ids: + ids.append(d.metadata[self.id_key]) + docs = await self.docstore.amget(ids) + return [d for d in docs if d is not None] diff --git a/libs/langchain/tests/unit_tests/retrievers/test_multi_vector.py b/libs/langchain/tests/unit_tests/retrievers/test_multi_vector.py index 0d5f9a18368..648d8252141 100644 --- a/libs/langchain/tests/unit_tests/retrievers/test_multi_vector.py +++ b/libs/langchain/tests/unit_tests/retrievers/test_multi_vector.py @@ -28,3 +28,16 @@ def test_multi_vector_retriever_initialization() -> None: results = retriever.invoke("1") assert len(results) > 0 assert results[0].page_content == "test document" + + +async def test_multi_vector_retriever_initialization_async() -> None: + vectorstore = InMemoryVectorstoreWithSearch() + retriever = MultiVectorRetriever( + vectorstore=vectorstore, docstore=InMemoryStore(), doc_id="doc_id" + ) + documents = [Document(page_content="test document", metadata={"doc_id": "1"})] + await retriever.vectorstore.aadd_documents(documents, ids=["1"]) + await retriever.docstore.amset(list(zip(["1"], documents))) + results = await retriever.ainvoke("1") + assert len(results) > 0 + assert results[0].page_content == "test document"