mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-18 10:43:36 +00:00
Keep also original query - multi_query.py (#12696)
When you use a MultiQuery it might be useful to use the original query as well as the newly generated ones to maximise the changes to retriever the correct document. I haven't created an issue, it seems a very small and easy thing. --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
4fe9bf70b6
commit
2e2b9c76d9
@ -61,6 +61,8 @@ class MultiQueryRetriever(BaseRetriever):
|
|||||||
llm_chain: LLMChain
|
llm_chain: LLMChain
|
||||||
verbose: bool = True
|
verbose: bool = True
|
||||||
parser_key: str = "lines"
|
parser_key: str = "lines"
|
||||||
|
include_original: bool = False
|
||||||
|
"""Whether to include the original query in the list of generated queries."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_llm(
|
def from_llm(
|
||||||
@ -69,12 +71,15 @@ class MultiQueryRetriever(BaseRetriever):
|
|||||||
llm: BaseLLM,
|
llm: BaseLLM,
|
||||||
prompt: PromptTemplate = DEFAULT_QUERY_PROMPT,
|
prompt: PromptTemplate = DEFAULT_QUERY_PROMPT,
|
||||||
parser_key: str = "lines",
|
parser_key: str = "lines",
|
||||||
|
include_original: bool = False,
|
||||||
) -> "MultiQueryRetriever":
|
) -> "MultiQueryRetriever":
|
||||||
"""Initialize from llm using default template.
|
"""Initialize from llm using default template.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
retriever: retriever to query documents from
|
retriever: retriever to query documents from
|
||||||
llm: llm for query generation using DEFAULT_QUERY_PROMPT
|
llm: llm for query generation using DEFAULT_QUERY_PROMPT
|
||||||
|
include_original: Whether to include the original query in the list of
|
||||||
|
generated queries.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
MultiQueryRetriever
|
MultiQueryRetriever
|
||||||
@ -85,6 +90,7 @@ class MultiQueryRetriever(BaseRetriever):
|
|||||||
retriever=retriever,
|
retriever=retriever,
|
||||||
llm_chain=llm_chain,
|
llm_chain=llm_chain,
|
||||||
parser_key=parser_key,
|
parser_key=parser_key,
|
||||||
|
include_original=include_original,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
@ -102,6 +108,8 @@ class MultiQueryRetriever(BaseRetriever):
|
|||||||
Unique union of relevant documents from all generated queries
|
Unique union of relevant documents from all generated queries
|
||||||
"""
|
"""
|
||||||
queries = await self.agenerate_queries(query, run_manager)
|
queries = await self.agenerate_queries(query, run_manager)
|
||||||
|
if self.include_original:
|
||||||
|
queries.append(query)
|
||||||
documents = await self.aretrieve_documents(queries, run_manager)
|
documents = await self.aretrieve_documents(queries, run_manager)
|
||||||
return self.unique_union(documents)
|
return self.unique_union(documents)
|
||||||
|
|
||||||
@ -160,6 +168,8 @@ class MultiQueryRetriever(BaseRetriever):
|
|||||||
Unique union of relevant documents from all generated queries
|
Unique union of relevant documents from all generated queries
|
||||||
"""
|
"""
|
||||||
queries = self.generate_queries(query, run_manager)
|
queries = self.generate_queries(query, run_manager)
|
||||||
|
if self.include_original:
|
||||||
|
queries.append(query)
|
||||||
documents = self.retrieve_documents(queries, run_manager)
|
documents = self.retrieve_documents(queries, run_manager)
|
||||||
return self.unique_union(documents)
|
return self.unique_union(documents)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user