Retriever that can re-phase user inputs (#8026)

Simple retriever that applies an LLM between the user input and the
query pass the to retriever.

It can be used to pre-process the user input in any way.

The default prompt:

```
DEFAULT_QUERY_PROMPT = PromptTemplate(
    input_variables=["question"],
    template="""You are an assistant tasked with taking a natural languge query from a user
    and converting it into a query for a vectorstore. In this process, you strip out
    information that is not relevant for the retrieval task. Here is the user query: {question} """
)
```

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
Lance Martin
2023-08-03 21:23:59 -07:00
committed by GitHub
parent 6c3573e7f6
commit d1b95db874
3 changed files with 311 additions and 0 deletions

View File

@@ -42,6 +42,7 @@ from langchain.retrievers.milvus import MilvusRetriever
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.retrievers.pinecone_hybrid_search import PineconeHybridSearchRetriever
from langchain.retrievers.pubmed import PubMedRetriever
from langchain.retrievers.re_phraser import RePhraseQueryRetriever
from langchain.retrievers.remote_retriever import RemoteLangChainRetriever
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain.retrievers.svm import SVMRetriever
@@ -86,6 +87,7 @@ __all__ = [
"ZepRetriever",
"ZillizRetriever",
"DocArrayRetriever",
"RePhraseQueryRetriever",
"WebResearchRetriever",
"EnsembleRetriever",
]

View File

@@ -0,0 +1,87 @@
import logging
from typing import List
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BaseRetriever, Document
logger = logging.getLogger(__name__)
# Default template
DEFAULT_TEMPLATE = """You are an assistant tasked with taking a natural language \
query from a user and converting it into a query for a vectorstore. \
In this process, you strip out information that is not relevant for \
the retrieval task. Here is the user query: {question}"""
# Default prompt
DEFAULT_QUERY_PROMPT = PromptTemplate.from_template(DEFAULT_TEMPLATE)
class RePhraseQueryRetriever(BaseRetriever):
"""Given a user query, use an LLM to re-phrase it.
Then, retrieve docs for re-phrased query."""
retriever: BaseRetriever
llm_chain: LLMChain
@classmethod
def from_llm(
cls,
retriever: BaseRetriever,
llm: BaseLLM,
prompt: PromptTemplate = DEFAULT_QUERY_PROMPT,
) -> "RePhraseQueryRetriever":
"""Initialize from llm using default template.
The prompt used here expects a single input: `question`
Args:
retriever: retriever to query documents from
llm: llm for query generation using DEFAULT_QUERY_PROMPT
prompt: prompt template for query generation
Returns:
RePhraseQueryRetriever
"""
llm_chain = LLMChain(llm=llm, prompt=prompt)
return cls(
retriever=retriever,
llm_chain=llm_chain,
)
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
) -> List[Document]:
"""Get relevated documents given a user question.
Args:
query: user question
Returns:
Relevant documents for re-phrased question
"""
response = self.llm_chain(query, callbacks=run_manager.get_child())
re_phrased_question = response["text"]
logger.info(f"Re-phrased question: {re_phrased_question}")
docs = self.retriever.get_relevant_documents(
re_phrased_question, callbacks=run_manager.get_child()
)
return docs
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
) -> List[Document]:
raise NotImplementedError