Web research retriever (#8102)

Given a user question, this will -
* Use LLM to generate a set of queries.
* Query for each.
* The URLs from search results are stored in self.urls.
* A check is performed for any new URLs that haven't been processed yet
(not in self.url_database).
* Only these new URLs are loaded, transformed, and added to the
vectorstore.
* The vectorstore is queried for relevant documents based on the
questions generated by the LLM.
* Only unique documents are returned as the final result.

This code will avoid reprocessing of URLs across multiple runs of
similar queries, which should improve the performance of the retriever.
It also keeps track of all URLs that have been processed, which could be
useful for debugging or understanding the retriever's behavior.

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
Lance Martin
2023-07-25 19:58:00 -07:00
committed by GitHub
parent d1d691caa4
commit 7a00f17033
4 changed files with 803 additions and 2 deletions

View File

@@ -1,4 +1,3 @@
"""Web base loader class."""
import asyncio
import logging
import warnings
@@ -86,7 +85,12 @@ class AsyncHtmlLoader(BaseLoader):
headers=self.session.headers,
ssl=None if self.session.verify else False,
) as response:
return await response.text()
try:
text = await response.text()
except UnicodeDecodeError:
logger.error(f"Failed to decode content from {url}")
text = ""
return text
except aiohttp.ClientConnectionError as e:
if i == retries - 1:
raise

View File

@@ -31,6 +31,7 @@ from langchain.retrievers.time_weighted_retriever import (
)
from langchain.retrievers.vespa_retriever import VespaRetriever
from langchain.retrievers.weaviate_hybrid_search import WeaviateHybridSearchRetriever
from langchain.retrievers.web_research import WebResearchRetriever
from langchain.retrievers.wikipedia import WikipediaRetriever
from langchain.retrievers.zep import ZepRetriever
from langchain.retrievers.zilliz import ZillizRetriever
@@ -65,5 +66,6 @@ __all__ = [
"ZepRetriever",
"ZillizRetriever",
"DocArrayRetriever",
"WebResearchRetriever",
"EnsembleRetriever",
]

View File

@@ -0,0 +1,215 @@
import logging
import re
from typing import List, Optional
from pydantic import BaseModel, Field
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.chains import LLMChain
from langchain.chains.prompt_selector import ConditionalPromptSelector
from langchain.document_loaders import AsyncHtmlLoader
from langchain.document_transformers import Html2TextTransformer
from langchain.llms import LlamaCpp
from langchain.llms.base import BaseLLM
from langchain.output_parsers.pydantic import PydanticOutputParser
from langchain.prompts import BasePromptTemplate, PromptTemplate
from langchain.schema import BaseRetriever, Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.utilities import GoogleSearchAPIWrapper
from langchain.vectorstores.base import VectorStore
logger = logging.getLogger(__name__)
class SearchQueries(BaseModel):
"""Search queries to run to research for the user's goal."""
queries: List[str] = Field(
..., description="List of search queries to look up on Google"
)
DEFAULT_LLAMA_SEARCH_PROMPT = PromptTemplate(
input_variables=["question"],
template="""<<SYS>> \n You are an assistant tasked with improving Google search
results. \n <</SYS>> \n\n [INST] Generate FIVE Google search queries that
are similar to this question. The output should be a numbered list of questions
and each should have a question mark at the end: \n\n {question} [/INST]""",
)
DEFAULT_SEARCH_PROMPT = PromptTemplate(
input_variables=["question"],
template="""You are an assistant tasked with improving Google search
results. Generate FIVE Google search queries that are similar to
this question. The output should be a numbered list of questions and each
should have a question mark at the end: {question}""",
)
class LineList(BaseModel):
"""List of questions."""
lines: List[str] = Field(description="Questions")
class QuestionListOutputParser(PydanticOutputParser):
"""Output parser for a list of numbered questions."""
def __init__(self) -> None:
super().__init__(pydantic_object=LineList)
def parse(self, text: str) -> LineList:
lines = re.findall(r"\d+\..*?\n", text)
return LineList(lines=lines)
class WebResearchRetriever(BaseRetriever):
# Inputs
vectorstore: VectorStore = Field(
..., description="Vector store for storing web pages"
)
llm_chain: LLMChain
search: GoogleSearchAPIWrapper = Field(..., description="Google Search API Wrapper")
max_splits_per_doc: int = Field(100, description="Maximum splits per document")
num_search_results: int = Field(1, description="Number of pages per Google search")
text_splitter: RecursiveCharacterTextSplitter = Field(
RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=50),
description="Text splitter for splitting web pages into chunks",
)
url_database: List[str] = Field(
default_factory=list, description="List of processed URLs"
)
@classmethod
def from_llm(
cls,
vectorstore: VectorStore,
llm: BaseLLM,
search: GoogleSearchAPIWrapper,
prompt: Optional[BasePromptTemplate] = None,
max_splits_per_doc: int = 100,
num_search_results: int = 1,
text_splitter: RecursiveCharacterTextSplitter = RecursiveCharacterTextSplitter(
chunk_size=1500, chunk_overlap=50
),
) -> "WebResearchRetriever":
"""Initialize from llm using default template.
Args:
vectorstore: Vector store for storing web pages
llm: llm for search question generation
search: GoogleSearchAPIWrapper
prompt: prompt to generating search questions
max_splits_per_doc: Maximum splits per document to keep
num_search_results: Number of pages per Google search
text_splitter: Text splitter for splitting web pages into chunks
Returns:
WebResearchRetriever
"""
if not prompt:
QUESTION_PROMPT_SELECTOR = ConditionalPromptSelector(
default_prompt=DEFAULT_SEARCH_PROMPT,
conditionals=[
(lambda llm: isinstance(llm, LlamaCpp), DEFAULT_LLAMA_SEARCH_PROMPT)
],
)
prompt = QUESTION_PROMPT_SELECTOR.get_prompt(llm)
# Use chat model prompt
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
output_parser=QuestionListOutputParser(),
)
return cls(
vectorstore=vectorstore,
llm_chain=llm_chain,
search=search,
max_splits_per_doc=max_splits_per_doc,
num_search_results=num_search_results,
text_splitter=text_splitter,
)
def search_tool(self, query: str, num_search_results: int = 1) -> List[dict]:
"""Returns num_serch_results pages per Google search."""
result = self.search.results(query, num_search_results)
return result
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
) -> List[Document]:
"""Search Google for documents related to the query input.
Args:
query: user query
Returns:
Relevant documents from all various urls.
"""
# Get search questions
logger.info("Generating questions for Google Search ...")
result = self.llm_chain({"question": query})
logger.info(f"Questions for Google Search (raw): {result}")
questions = getattr(result["text"], "lines", [])
logger.info(f"Questions for Google Search: {questions}")
# Get urls
logger.info("Searching for relevat urls ...")
urls_to_look = []
for query in questions:
# Google search
search_results = self.search_tool(query, self.num_search_results)
logger.info("Searching for relevat urls ...")
logger.info(f"Search results: {search_results}")
for res in search_results:
urls_to_look.append(res["link"])
# Relevant urls
urls = set(urls_to_look)
# Check for any new urls that we have not processed
new_urls = list(urls.difference(self.url_database))
logger.info(f"New URLs to load: {new_urls}")
# Load, split, and add new urls to vectorstore
if new_urls:
loader = AsyncHtmlLoader(new_urls)
html2text = Html2TextTransformer()
logger.info("Indexing new urls...")
docs = loader.load()
docs = list(html2text.transform_documents(docs))
docs = self.text_splitter.split_documents(docs)
self.vectorstore.add_documents(docs)
self.url_database.extend(new_urls)
# Search for relevant splits
# TODO: make this async
logger.info("Grabbing most relevant splits from urls...")
docs = []
for query in questions:
docs.extend(self.vectorstore.similarity_search(query))
# Get unique docs
unique_documents_dict = {
(doc.page_content, tuple(sorted(doc.metadata.items()))): doc for doc in docs
}
unique_documents = list(unique_documents_dict.values())
return unique_documents
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
) -> List[Document]:
raise NotImplementedError