mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-31 12:09:58 +00:00
Add async support to multi-query retriever. (#10873)
Added async support to the MultiQueryRetriever class. --------- Co-authored-by: Nuno Campos <nuno@boringbits.io>
This commit is contained in:
parent
4eee789dd3
commit
66d5a7e7cf
@ -1,7 +1,11 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import List, Sequence
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.output_parsers.pydantic import PydanticOutputParser
|
||||
@ -83,6 +87,64 @@ class MultiQueryRetriever(BaseRetriever):
|
||||
parser_key=parser_key,
|
||||
)
|
||||
|
||||
async def _aget_relevant_documents(
    self,
    query: str,
    *,
    run_manager: AsyncCallbackManagerForRetrieverRun,
) -> List[Document]:
    """Asynchronously fetch documents for a user query.

    The query is first expanded into several generated queries, each
    generated query is run against the retriever, and the combined
    results are passed through ``unique_union``.

    Args:
        query: user query
        run_manager: callback manager handed to both the query
            generation step and the retrieval step

    Returns:
        Unique union of relevant documents from all generated queries
    """
    # Step 1: expand the user query into generated variants.
    generated = await self.agenerate_queries(query, run_manager)
    # Step 2: retrieve documents for every generated variant.
    retrieved = await self.aretrieve_documents(generated, run_manager)
    # Step 3: collapse the combined results via unique_union.
    unique_docs = self.unique_union(retrieved)
    return unique_docs
|
||||
|
||||
async def agenerate_queries(
    self, question: str, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[str]:
    """Asynchronously produce alternative queries for the user's question.

    Calls ``self.llm_chain.acall`` with the question and reads the
    attribute named by ``self.parser_key`` from the ``"text"`` entry of
    the chain response, falling back to an empty list when that
    attribute is absent.

    Args:
        question: user query
        run_manager: callback manager whose child is passed to the chain

    Returns:
        List of LLM generated queries that are similar to the user input
    """
    chain_inputs = {"question": question}
    chain_output = await self.llm_chain.acall(
        inputs=chain_inputs, callbacks=run_manager.get_child()
    )
    parsed_text = chain_output["text"]
    # The parsed output exposes the queries under a configurable attribute.
    lines = getattr(parsed_text, self.parser_key, [])
    if self.verbose:
        logger.info(f"Generated queries: {lines}")
    return lines
|
||||
|
||||
async def aretrieve_documents(
    self, queries: List[str], run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
    """Run every generated query against the retriever concurrently.

    One ``aget_relevant_documents`` call is created per query, each with
    its own child of ``run_manager``; all calls are awaited together via
    ``asyncio.gather`` and the per-query results are concatenated in
    query order.

    Args:
        queries: query list
        run_manager: callback manager whose children are passed to the
            retriever calls

    Returns:
        List of retrieved Documents
    """
    pending = [
        self.retriever.aget_relevant_documents(
            single_query, callbacks=run_manager.get_child()
        )
        for single_query in queries
    ]
    per_query_results = await asyncio.gather(*pending)

    # Flatten the list of per-query result lists into a single list.
    flattened = []
    for batch in per_query_results:
        flattened.extend(batch)
    return flattened
|
||||
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
|
Loading…
Reference in New Issue
Block a user