mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-05 12:48:12 +00:00
langchain[patch]: Add async methods to MultiVectorRetriever (#16878)
Adds async support to multi vector retriever
This commit is contained in:
parent
7d03d8f586
commit
78a1af4848
@ -1,7 +1,10 @@
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
@ -71,3 +74,30 @@ class MultiVectorRetriever(BaseRetriever):
|
||||
ids.append(d.metadata[self.id_key])
|
||||
docs = self.docstore.mget(ids)
|
||||
return [d for d in docs if d is not None]
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
"""Asynchronously get documents relevant to a query.
|
||||
Args:
|
||||
query: String to find relevant documents for
|
||||
run_manager: The callbacks handler to use
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
if self.search_type == SearchType.mmr:
|
||||
sub_docs = await self.vectorstore.amax_marginal_relevance_search(
|
||||
query, **self.search_kwargs
|
||||
)
|
||||
else:
|
||||
sub_docs = await self.vectorstore.asimilarity_search(
|
||||
query, **self.search_kwargs
|
||||
)
|
||||
|
||||
# We do this to maintain the order of the ids that are returned
|
||||
ids = []
|
||||
for d in sub_docs:
|
||||
if self.id_key in d.metadata and d.metadata[self.id_key] not in ids:
|
||||
ids.append(d.metadata[self.id_key])
|
||||
docs = await self.docstore.amget(ids)
|
||||
return [d for d in docs if d is not None]
|
||||
|
@ -28,3 +28,16 @@ def test_multi_vector_retriever_initialization() -> None:
|
||||
results = retriever.invoke("1")
|
||||
assert len(results) > 0
|
||||
assert results[0].page_content == "test document"
|
||||
|
||||
|
||||
async def test_multi_vector_retriever_initialization_async() -> None:
|
||||
vectorstore = InMemoryVectorstoreWithSearch()
|
||||
retriever = MultiVectorRetriever(
|
||||
vectorstore=vectorstore, docstore=InMemoryStore(), doc_id="doc_id"
|
||||
)
|
||||
documents = [Document(page_content="test document", metadata={"doc_id": "1"})]
|
||||
await retriever.vectorstore.aadd_documents(documents, ids=["1"])
|
||||
await retriever.docstore.amset(list(zip(["1"], documents)))
|
||||
results = await retriever.ainvoke("1")
|
||||
assert len(results) > 0
|
||||
assert results[0].page_content == "test document"
|
||||
|
Loading…
Reference in New Issue
Block a user