This commit is contained in:
Eugene Yurtsev
2023-05-30 14:54:42 -04:00
parent 69cc8d7f73
commit 08b6c75743
3 changed files with 73 additions and 28 deletions

View File

@@ -4,7 +4,6 @@ The main idea behind the crawling module is to identify additional links
that are worth exploring to find more documents that may be relevant for being
able to answer the question correctly.
"""
import abc
import json
import urllib.parse
from bs4 import PageElement, BeautifulSoup
@@ -12,6 +11,7 @@ from typing import List, Dict, Any, Tuple
from langchain.base_language import BaseLanguageModel
from langchain.chains.classification.multiselection import MultiSelectChain
from langchain.chains.research.typedefs import BlobCrawler
from langchain.document_loaders.base import BaseBlobParser
from langchain.document_loaders.blob_loaders import Blob
from langchain.document_loaders.parsers.html.markdownify import MarkdownifyHTMLParser
@@ -82,14 +82,6 @@ def _extract_records(blob: Blob) -> Tuple[List[Dict[str, Any]], Tuple[str, ...]]
)
class BlobCrawler(abc.ABC):
    """Crawl a blob and identify links to related content."""

    @abc.abstractmethod
    def crawl(self, blob: Blob, query: str) -> List[str]:
        """Explore the blob and identify links to related content that is relevant to the query.

        Args:
            blob: the blob to inspect for links.
            query: the query the discovered links should be relevant to.

        Returns:
            a list of links worth exploring further.
        """
class ChainCrawler(BlobCrawler):
def __init__(self, chain: MultiSelectChain, parser: BaseBlobParser) -> None:
"""Crawl the blob using an LLM."""

View File

@@ -1,11 +1,12 @@
"""Module for initiating a set of searches relevant for answering the question."""
import abc
import asyncio
import itertools
from bs4 import BeautifulSoup
from typing import Sequence, List, Mapping, Any
from langchain import PromptTemplate, LLMChain, serpapi
from langchain.base_language import BaseLanguageModel
from langchain.chains.research.typedefs import AbstractQueryGenerator, AbstractSearcher
from langchain.schema import BaseOutputParser
@@ -118,7 +119,7 @@ def _deduplicate_objects(
return deduped
def run_searches(queries: Sequence[str]) -> List[Mapping[str, Any]]:
def _run_searches(queries: Sequence[str]) -> List[Mapping[str, Any]]:
"""Run the given queries and return the unique results.
Args:
@@ -133,31 +134,46 @@ def run_searches(queries: Sequence[str]) -> List[Mapping[str, Any]]:
result = wrapper.results(query)
organic_results = result["organic_results"]
results.extend(organic_results)
return results
unique_results = _deduplicate_objects(results, "link")
return unique_results
async def _arun_searches(queries: Sequence[str]) -> List[Mapping[str, Any]]:
    """Run the given queries and return the unique results.

    Args:
        queries: the queries to run searches for.

    Returns:
        the flattened "organic_results" of every query.

    NOTE(review): unlike the sync variant, this does NOT deduplicate —
    the caller is expected to deduplicate (e.g. via ``_deduplicate_objects``).
    """
    wrapper = serpapi.SerpAPIWrapper()
    # NOTE(review): ``wrapper.results`` looks like a synchronous call — if so,
    # these are plain dicts, not awaitables, and ``asyncio.gather`` will raise.
    # Presumably the async ``aresults`` API was intended — TODO confirm.
    tasks = [wrapper.results(query) for query in queries]
    results = await asyncio.gather(*tasks)
    # Flatten the per-query organic results into a single list.
    return list(
        itertools.chain.from_iterable(result["organic_results"] for result in results)
    )
# PUBLIC API
def generate_queries(llm: BaseLanguageModel, question: str) -> List[str]:
    """Generate search queries for the given question.

    Args:
        llm: the language model used to draft the queries.
        question: the question the queries should help answer.

    Returns:
        a list of generated queries.
    """
    query_chain = LLMChain(prompt=QUERY_GENERATION_PROMPT, llm=llm)
    return query_chain.predict_and_parse(question=question)
class GenericQueryGenerator(AbstractQueryGenerator):
    """Query generator backed by an LLM chain."""

    def __init__(self, llm: BaseLanguageModel) -> None:
        """Initialize the query generator.

        Args:
            llm: the language model used to generate queries.
        """
        self.llm = llm

    def generate_queries(self, question: str) -> List[str]:
        """Generate search queries for the given question."""
        query_chain = LLMChain(prompt=QUERY_GENERATION_PROMPT, llm=self.llm)
        return query_chain.predict_and_parse(question=question)
class AbstractSearcher(abc.ABC):
@abc.abstractmethod
def search(self, query: str) -> List[Mapping[str, Any]]:
"""Run a search for the given query.
class GenericSearcher(AbstractSearcher):
def __init__(self, search_engine: str = "serp") -> None:
"""Initialize the searcher.
Args:
query: the query to run the search for.
search_engine: the search engine to use. Placeholder for future work.
"""
raise NotImplementedError()
if search_engine != "serp":
raise NotImplementedError("Only serp is supported at the moment.")
# def search_all(self, queries: Sequence[str]) -> List[Mapping[str, Any]]:
# """Run a search for all the given queries."""
# raise NotImplementedError
async def asearch(self, queries: Sequence[str]) -> List[Mapping[str, Any]]:
    """Run searches for the given queries asynchronously.

    Args:
        queries: the queries to run searches for.

    Returns:
        the search results, deduplicated by their "link" field.
    """
    # ``await`` is only legal inside a coroutine function: the original
    # plain ``def`` was a SyntaxError, so this must be ``async def``.
    results = await _arun_searches(queries)
    return _deduplicate_objects(results, "link")

View File

@@ -0,0 +1,37 @@
from typing import List, Sequence, Mapping, Any
import abc
from langchain.document_loaders.blob_loaders import Blob
class AbstractQueryGenerator(abc.ABC):
    """Interface for components that turn a question into search queries."""

    @abc.abstractmethod
    def generate_queries(self, question: str) -> List[str]:
        """Produce search queries for the given question.

        Args:
            question: the question to generate queries for.

        Returns:
            a list of queries.
        """
        raise NotImplementedError()
class AbstractSearcher(abc.ABC):
    """Abstract class for running searches."""

    # Was missing @abc.abstractmethod: the class is documented as abstract and
    # the sibling interfaces in this module mark their methods abstract, but
    # concrete subclasses were not actually forced to implement ``search``.
    @abc.abstractmethod
    def search(self, queries: Sequence[str]) -> List[Mapping[str, Any]]:
        """Run a search for each of the given queries.

        Args:
            queries: the queries to run searches for.

        Returns:
            a list of search results.
        """
        raise NotImplementedError()
class BlobCrawler(abc.ABC):
    """Interface for crawlers that mine a blob for follow-up links."""

    @abc.abstractmethod
    def crawl(self, blob: Blob, query: str) -> List[str]:
        """Identify links in the blob that are relevant to the query.

        Args:
            blob: the blob to explore.
            query: the query the discovered links should be relevant to.

        Returns:
            a list of links to related content.
        """