mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-16 09:48:04 +00:00
Keep also original query - multi_query.py (#12696)
When you use a MultiQuery it might be useful to use the original query as well as the newly generated ones to maximise the changes to retriever the correct document. I haven't created an issue, it seems a very small and easy thing. --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
4fe9bf70b6
commit
2e2b9c76d9
@ -61,6 +61,8 @@ class MultiQueryRetriever(BaseRetriever):
|
||||
llm_chain: LLMChain
|
||||
verbose: bool = True
|
||||
parser_key: str = "lines"
|
||||
include_original: bool = False
|
||||
"""Whether to include the original query in the list of generated queries."""
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
@ -69,12 +71,15 @@ class MultiQueryRetriever(BaseRetriever):
|
||||
llm: BaseLLM,
|
||||
prompt: PromptTemplate = DEFAULT_QUERY_PROMPT,
|
||||
parser_key: str = "lines",
|
||||
include_original: bool = False,
|
||||
) -> "MultiQueryRetriever":
|
||||
"""Initialize from llm using default template.
|
||||
|
||||
Args:
|
||||
retriever: retriever to query documents from
|
||||
llm: llm for query generation using DEFAULT_QUERY_PROMPT
|
||||
include_original: Whether to include the original query in the list of
|
||||
generated queries.
|
||||
|
||||
Returns:
|
||||
MultiQueryRetriever
|
||||
@ -85,6 +90,7 @@ class MultiQueryRetriever(BaseRetriever):
|
||||
retriever=retriever,
|
||||
llm_chain=llm_chain,
|
||||
parser_key=parser_key,
|
||||
include_original=include_original,
|
||||
)
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
@ -102,6 +108,8 @@ class MultiQueryRetriever(BaseRetriever):
|
||||
Unique union of relevant documents from all generated queries
|
||||
"""
|
||||
queries = await self.agenerate_queries(query, run_manager)
|
||||
if self.include_original:
|
||||
queries.append(query)
|
||||
documents = await self.aretrieve_documents(queries, run_manager)
|
||||
return self.unique_union(documents)
|
||||
|
||||
@ -160,6 +168,8 @@ class MultiQueryRetriever(BaseRetriever):
|
||||
Unique union of relevant documents from all generated queries
|
||||
"""
|
||||
queries = self.generate_queries(query, run_manager)
|
||||
if self.include_original:
|
||||
queries.append(query)
|
||||
documents = self.retrieve_documents(queries, run_manager)
|
||||
return self.unique_union(documents)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user