This commit is contained in:
Eugene Yurtsev
2023-06-01 16:39:52 -04:00
parent 1f1db7c96a
commit f449d26083
2 changed files with 187 additions and 120 deletions

View File

@@ -3,13 +3,17 @@ from __future__ import annotations
import asyncio
import itertools
from bs4 import BeautifulSoup
from typing import Sequence, List, Mapping, Any, Optional
from typing import Sequence, List, Mapping, Any, Optional, Union, Dict
from langchain import PromptTemplate, LLMChain, serpapi
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.classification.multiselection import MultiSelectChain
from langchain.chains.research.typedefs import AbstractQueryGenerator, AbstractSearcher
from langchain.schema import BaseOutputParser
from langchain.chains.base import Chain
def _extract_content_from_tag(html: str, tag: str) -> List[str]:
@@ -47,7 +51,6 @@ class QueryExtractor(BaseOutputParser[List[str]]):
return _extract_content_from_tag(text, "query")
# TODO(Eugene): add a version that works for chat models as well w/ human system message?
QUERY_GENERATION_PROMPT = PromptTemplate.from_template(
"""\
Suggest a few different search queries that could be used to identify web-pages \
@@ -121,11 +124,15 @@ def _deduplicate_objects(
return deduped
def _run_searches(queries: Sequence[str]) -> List[Mapping[str, Any]]:
"""Run the given queries and return the unique results.
def _run_searches(queries: Sequence[str], top_k: int = -1) -> List[Mapping[str, Any]]:
"""Run the given queries and return all the search results.
This function can return duplicated results, de-duplication can take place later
and take into account the frequency of appearance.
Args:
queries: a list of queries to run
top_k: the number of results to return, if -1 return all results
Returns:
a list of search results (may contain duplicates)
@@ -134,26 +141,50 @@ def _run_searches(queries: Sequence[str]) -> List[Mapping[str, Any]]:
results = []
for query in queries:
result = wrapper.results(query)
organic_results = result["organic_results"]
if top_k <= 0:
organic_results = result["organic_results"]
else:
organic_results = result["organic_results"][:top_k]
results.extend(organic_results)
return results
async def _arun_searches(
    queries: Sequence[str], top_k: int = -1
) -> List[Mapping[str, Any]]:
    """Run the given queries concurrently and return all the search results.

    This function can return duplicated results; de-duplication can take place
    later and take into account the frequency of appearance.

    Args:
        queries: a list of queries to run
        top_k: the number of results to keep per query; if <= 0 keep all results

    Returns:
        a flat list of the organic search results of every query, in query
        order; the list may contain duplicates
    """
    wrapper = serpapi.SerpAPIWrapper()
    # `results` is the blocking call and returns a plain dict, which
    # `asyncio.gather` cannot await; `aresults` is its coroutine counterpart.
    tasks = [wrapper.aresults(query) for query in queries]
    results = await asyncio.gather(*tasks)

    finalized_results: List[Mapping[str, Any]] = []
    for result in results:
        organic_results = result["organic_results"]
        if top_k > 0:
            # Truncate per query, so the total can be up to len(queries) * top_k.
            organic_results = organic_results[:top_k]
        finalized_results.extend(organic_results)
    return finalized_results
# PUBLIC API
def make_query_generator(llm: BaseLanguageModel) -> LLMChain:
    """Use an LLM to break down a complex query into a list of simpler queries.

    The simpler queries are meant to be run against a search engine.

    Query:

        Does Harrison Chase of Langchain like to eat pizza or play squash?

    May be broken down into a list of simpler queries like:

        - Harrison Chase
        - Harrison Chase Langchain
        - Harrison Chase Langchain pizza
        - Harrison Chase Langchain squash

    Args:
        llm: the language model used to generate the queries

    Returns:
        an LLMChain that sends QUERY_GENERATION_PROMPT to the given model
    """
    # The template is passed under the `prompt` field (it was misspelled
    # `promp`). `output_keys` is not a constructor argument of LLMChain, so it
    # is no longer passed.
    return LLMChain(llm=llm, prompt=QUERY_GENERATION_PROMPT)
class GenericSearcher(Chain):
    """A chain that takes a complex question and identifies a list of relevant urls.

    The chain works by:

    1. Breaking a complex question into a series of simpler queries using an LLM.
    2. Running the queries against a search engine.
    3. Selecting the most relevant urls using an LLM (can be replaced with tf-idf
       or other models).

    This chain is not meant to be used for questions requiring multiple hops to
    answer. For example, the age of Leonardo DiCaprio's girlfriend is a multi-hop
    question; this kind of question requires a slightly different approach.

    This chain is meant to handle questions for which one wants
    to collect information from multiple sources.

    To extend implementation:
    * Parameterize the search engine (to allow for non serp api search engines)
    * Expose an abstract interface for query generator and link selection model
    """

    query_generator: LLMChain
    """An LLM that is used to break down a complex question into a list of simpler queries."""
    # NOTE(review): `from_llm` passes a MultiSelectChain here; confirm that it
    # satisfies this annotation and exposes `predict_and_parse`.
    link_selection_model: LLMChain
    """An LLM that is used to select the most relevant urls from the search results."""
    top_k_per_search: int = -1
    """The number of top urls to select from each search."""

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys."""
        return ["question"]

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys."""
        return ["urls"]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Generate queries, run the searches and select the relevant urls.

        Args:
            inputs: a mapping holding the user question under "question"
            run_manager: optional callback manager for this run

        Returns:
            a mapping with the selected links under "urls"
        """
        # `run_manager` defaults to None, so fall back to a no-op manager
        # before asking it for child callbacks.
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        question = inputs["question"]
        response = self.query_generator.predict_and_parse(
            callbacks=_run_manager.get_child(), question=question
        )
        # NOTE(review): assumes the parsed output is a mapping with a "queries"
        # key; confirm against the output parser attached to the prompt.
        queries = response["queries"]
        results = _run_searches(queries, top_k=self.top_k_per_search)
        deduped_results = _deduplicate_objects(results, "link")
        # Only the link and the title are shown to the link selection model.
        records = [
            {"link": result["link"], "title": result["title"]}
            for result in deduped_results
        ]
        response_ = self.link_selection_model.predict_and_parse(
            callbacks=_run_manager.get_child(),
            question=question,
            choices=records,
        )
        return {"urls": [result["link"] for result in response_["selected"]]}

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Async version of `_call`.

        Args:
            inputs: a mapping holding the user question under "question"
            run_manager: optional async callback manager for this run

        Returns:
            a mapping with the selected links under "urls"
        """
        _run_manager = (
            run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        )
        question = inputs["question"]
        response = await self.query_generator.apredict_and_parse(
            callbacks=_run_manager.get_child(), question=question
        )
        # NOTE(review): assumes the parsed output is a mapping with a "queries"
        # key; confirm against the output parser attached to the prompt.
        queries = response["queries"]
        # Honour `top_k_per_search` exactly like the synchronous path does.
        results = await _arun_searches(queries, top_k=self.top_k_per_search)
        deduped_results = _deduplicate_objects(results, "link")
        records = [
            {"link": result["link"], "title": result["title"]}
            for result in deduped_results
        ]
        # Await the async variant so the event loop is not blocked and the
        # async child callbacks are handed to an async call.
        response_ = await self.link_selection_model.apredict_and_parse(
            callbacks=_run_manager.get_child(),
            question=question,
            choices=records,
        )
        return {"urls": [result["link"] for result in response_["selected"]]}

    @classmethod
    def from_llm(
        cls,
        link_selection_llm: BaseLanguageModel,
        query_generation_llm: BaseLanguageModel,
    ) -> GenericSearcher:
        """Initialize the searcher from language models.

        Args:
            link_selection_llm: the model used to select the relevant links
            query_generation_llm: the model used to generate the search queries

        Returns:
            a searcher wired with the default link selection chain and the
            default query generator
        """
        link_selection_model = MultiSelectChain.from_default(llm=link_selection_llm)
        query_generation_model = make_query_generator(llm=query_generation_llm)
        return cls(
            link_selection_model=link_selection_model,
            query_generator=query_generation_model,
        )

View File

@@ -5,52 +5,52 @@ import abc
from langchain.callbacks.manager import Callbacks
from langchain.document_loaders.blob_loaders import Blob
class AbstractQueryGenerator(abc.ABC):
    """Interface for components that turn a question into search queries."""

    @abc.abstractmethod
    def generate_queries(
        self,
        question: str,
        callbacks: Callbacks = None,
    ) -> List[str]:
        """Generate queries for the given question.

        Args:
            question: the question to generate queries for.
            callbacks: optional callbacks for the call.

        Returns:
            a list of queries.
        """
        raise NotImplementedError

    @abc.abstractmethod
    async def agenerate_queries(
        self,
        question: str,
        callbacks: Callbacks = None,
    ) -> List[str]:
        """Generate queries for the given question (async version).

        Args:
            question: the question to generate queries for.
            callbacks: optional callbacks for the call.

        Returns:
            a list of queries.
        """
        raise NotImplementedError
class AbstractSearcher(abc.ABC):
    """Base class for running searches.

    Neither method is marked abstract; the default implementations raise
    NotImplementedError.
    """

    def search(self, queries: Sequence[str]) -> List[Mapping[str, Any]]:
        """Run a search for the given queries.

        Args:
            queries: the queries to run the search for.

        Returns:
            a list of search results.
        """
        raise NotImplementedError

    async def asearch(self, queries: Sequence[str]) -> List[Mapping[str, Any]]:
        """Run a search for the given queries (async version).

        Args:
            queries: the queries to run the search for.

        Returns:
            a list of search results.
        """
        raise NotImplementedError
class BlobCrawler(abc.ABC):
    """Interface for crawling a blob to identify links to related content."""

    @abc.abstractmethod
    def crawl(
        self,
        blob: Blob,
        query: str,
        callbacks: Callbacks = None,
    ) -> List[str]:
        """Explore the blob and identify links to content relevant to the query.

        Args:
            blob: the blob to explore.
            query: the query the linked content should be relevant to.
            callbacks: optional callbacks for the call.

        Returns:
            a list of links.
        """
#
# class AbstractQueryGenerator(abc.ABC):
# """Abstract class for generating queries."""
#
# @abc.abstractmethod
# def generate_queries(self, question: str, callbacks: Callbacks = None) -> List[str]:
# """Generate queries for the given question."""
# raise NotImplementedError()
#
# @abc.abstractmethod
# async def agenerate_queries(
# self, question: str, callbacks: Callbacks = None
# ) -> List[str]:
# """Generate queries for the given question."""
# raise NotImplementedError()
#
#
# class AbstractSearcher(abc.ABC):
# """Abstract class for running searches."""
#
# def search(self, queries: Sequence[str]) -> List[Mapping[str, Any]]:
# """Run a search for the given query.
#
# Args:
# queries: the query to run the search for.
#
# Returns:
# a list of search results.
# """
# raise NotImplementedError()
#
# async def asearch(self, queries: Sequence[str]) -> List[Mapping[str, Any]]:
# """Run a search for the given query.
#
# Args:
# queries: the query to run the search for.
#
# Returns:
# a list of search results.
# """
# raise NotImplementedError()
#
#
# class BlobCrawler(abc.ABC):
# """Crawl a blob and identify links to related content."""
#
# @abc.abstractmethod
# def crawl(self, blob: Blob, query: str, callbacks: Callbacks = None) -> List[str]:
# """Explore the blob and identify links to related content that is relevant to the query."""