Eugene Yurtsev
2023-05-30 13:55:20 -04:00
parent c4b502a470
commit 65aad4c0bb
4 changed files with 807 additions and 0 deletions

@@ -0,0 +1,451 @@
import abc
import asyncio
import json
import urllib.parse
from bs4 import BeautifulSoup, PageElement
from typing import Sequence, List, Mapping, Any, Dict, Tuple, Optional
from langchain import PromptTemplate, LLMChain, serpapi
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
CallbackManagerForChainRun,
AsyncCallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.chains.research.multiselection import (
IDParser,
_extract_content_from_tag,
MultiSelectChain,
)
from langchain.document_loaders.base import BaseBlobParser
from langchain.document_loaders.blob_loaders import Blob
from langchain.document_loaders.parsers.html.markdownify import MarkdownifyHTMLParser
from langchain.schema import BaseOutputParser, Document
from langchain.text_splitter import TextSplitter
Parser = MarkdownifyHTMLParser(tags_to_remove=("svg", "img", "script", "style", "a"))
class AHrefExtractor(BaseOutputParser[List[str]]):
"""An output parser that extracts all a-href links."""
def parse(self, text: str) -> List[str]:
return _extract_href_tags(text)
URL_CRAWLING_PROMPT = PromptTemplate.from_template(
"""\
Here is a list of URLs extracted from a page titled: `{title}`.
```csv
{urls}
```
---
Here is a question:
{question}
---
Please output the ids of the URLs that may contain content relevant to answering the question. \
Use only the information in the CSV table of URLs to determine relevance.
Format your answer inside an <ids> tag, separating the ids with commas.
For example, if URLs 132 and 133 are relevant, you would write: <ids>132,133</ids>
Begin:""",
output_parser=IDParser(),
)
def _get_surrounding_text(tag: PageElement, n: int, *, is_before: bool = True) -> str:
"""Get surrounding text the given tag in the given direction.
Args:
tag: the tag to get surrounding text for.
n: number of characters to get
is_before: Whether to get text before or after the tag.
Returns:
the surrounding text in the given direction.
"""
text = ""
current = tag.previous_element if is_before else tag.next_element
while current and len(text) < n:
current_text = str(current.text).strip()
current_text = (
current_text
if len(current_text) + len(text) <= n
else current_text[: n - len(text)]
)
if is_before:
text = current_text + " " + text
else:
text = text + " " + current_text
current = current.previous_element if is_before else current.next_element
return text
def get_ahref_snippets(html: str, num_chars: int = 0) -> Dict[str, Any]:
"""Get a list of <a> tags as snippets from the given html.
Args:
html: the html to get snippets from.
num_chars: the number of characters to get around the <a> tags.
Returns:
a list of snippets.
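
    Example:
        A sketch of the returned structure (values depend on the page):

        .. code-block:: python

            get_ahref_snippets(html, num_chars=50)
            # -> {
            #     "title": "Page title",
            #     "snippets": [
            #         {"id": 0, "before": "...", "link": "https://...",
            #          "content": "anchor text", "after": "..."},
            #     ],
            # }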
"""
soup = BeautifulSoup(html, "html.parser")
    title = soup.title.string.strip() if soup.title and soup.title.string else ""
snippets = []
for idx, a_tag in enumerate(soup.find_all("a")):
before_text = _get_surrounding_text(a_tag, num_chars, is_before=True)
after_text = _get_surrounding_text(a_tag, num_chars, is_before=False)
snippet = {
"id": idx,
"before": before_text.strip().replace("\n", " "),
"link": a_tag.get("href").replace("\n", " ").strip(),
"content": a_tag.text.replace("\n", " ").strip(),
"after": after_text.strip().replace("\n", " "),
}
snippets.append(snippet)
return {
"snippets": snippets,
"title": title,
}
def _extract_href_tags(html: str) -> List[str]:
"""Extract href tags.
Args:
html: the html to extract href tags from.
Returns:
a list of href tags.
"""
href_tags = []
soup = BeautifulSoup(html, "html.parser")
for a_tag in soup.find_all("a"):
href = a_tag.get("href")
if href:
href_tags.append(href)
return href_tags
class QueryExtractor(BaseOutputParser[List[str]]):
"""An output parser that extracts all queries."""
def parse(self, text: str) -> List[str]:
"""Extract all content of <query> from the text."""
return _extract_content_from_tag(text, "query")
# TODO(Eugene): add a version that works for chat models as well w/ human system message?
QUERY_GENERATION_PROMPT = PromptTemplate.from_template(
"""\
Suggest a few different search queries that could be used to identify web-pages that could answer \
the following question.
If the question is about a named entity, start by listing very general searches (e.g., just the named entity) \
and then suggest searches that are more scoped to the question.
Input: ```Where did John Snow from Cambridge, UK work?```
Output: <query>John Snow</query>
<query>John Snow Cambridge UK</query>
<query>John Snow Cambridge UK work history</query>
<query>John Snow Cambridge UK cv</query>
Input: ```How many research papers did Jane Doe publish in 2010?```
Output: <query>Jane Doe</query>
<query>Jane Doe research papers</query>
<query>Jane Doe research papers 2010</query>
<query>Jane Doe publications</query>
<query>Jane Doe publications 2010</query>
Input: ```What is the capital of France?```
Output: <query>France</query>
<query>France capital</query>
<query>France capital city</query>
<query>France capital city name</query>
Input: ```What are the symptoms of COVID-19?```
Output: <query>COVID-19</query>
<query>COVID-19 symptoms</query>
<query>COVID-19 symptoms list</query>
<query>COVID-19 symptoms list WHO</query>
Input: ```What is the revenue stream of CVS?```
Output: <query>CVS</query>
<query>CVS revenue</query>
<query>CVS revenue stream</query>
<query>CVS revenue stream business model</query>
Input: ```{question}```
Output:
""",
output_parser=QueryExtractor(),
)
def generate_queries(llm: BaseLanguageModel, question: str) -> List[str]:
"""Generate queries using a Chain."""
chain = LLMChain(prompt=QUERY_GENERATION_PROMPT, llm=llm)
queries = chain.predict_and_parse(question=question)
return queries
def _deduplicate_objects(
dicts: Sequence[Mapping[str, Any]], key: str
) -> List[Mapping[str, Any]]:
"""Deduplicate objects by the given key."""
unique_values = set()
deduped: List[Mapping[str, Any]] = []
for d in dicts:
value = d[key]
if value not in unique_values:
unique_values.add(value)
deduped.append(d)
return deduped
def run_searches(queries: Sequence[str]) -> List[Mapping[str, Any]]:
"""Run the given queries and return the unique results.
Args:
queries: a list of queries to run
Returns:
a list of unique search results
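
    Note:
        This calls ``serpapi.SerpAPIWrapper``, so SerpAPI credentials (e.g. the
        ``SERPAPI_API_KEY`` environment variable) must be configured.

    Example:
        A sketch of typical usage:

        .. code-block:: python

            results = run_searches(["John Snow", "John Snow Cambridge UK"])
            links = [r["link"] for r in results]  # deduplicated organic results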
"""
wrapper = serpapi.SerpAPIWrapper()
results = []
for query in queries:
        result = wrapper.results(query)
        organic_results = result.get("organic_results", [])
        results.extend(organic_results)
unique_results = _deduplicate_objects(results, "link")
return unique_results
class BlobCrawler(abc.ABC):
"""Crawl a blob and identify links to related content."""
@abc.abstractmethod
    def crawl(self, blob: Blob, question: str) -> List[str]:
        """Explore the blob and identify links to content relevant to the question."""
def _extract_records(blob: Blob) -> Tuple[List[Mapping[str, Any]], Tuple[str, ...]]:
"""Extract records from a blob."""
if blob.mimetype == "text/html":
info = get_ahref_snippets(blob.as_string(), num_chars=100)
return (
[
{
"content": d["content"],
"link": d["link"],
"before": d["before"],
"after": d["after"],
}
for d in info["snippets"]
],
("link", "content", "before", "after"),
)
elif blob.mimetype == "application/json": # Represent search results
data = json.loads(blob.as_string())
results = data["results"]
return [
{
"title": result["title"],
"snippet": result["snippet"],
"link": result["link"],
}
for result in results
], ("link", "title", "snippet")
else:
        raise ValueError(
            f"Can only extract records from HTML/JSON blobs. Got {blob.mimetype}"
        )
class ChainCrawler(BlobCrawler):
    """Crawl a blob using an LLM to select which links to follow."""

    def __init__(self, chain: MultiSelectChain, parser: BaseBlobParser) -> None:
        """Initialize with a multi-select chain and a blob parser."""
        self.chain = chain
        self.parser = parser
def crawl(self, blob: Blob, question: str) -> List[str]:
"""Explore the blob and suggest additional content to explore."""
records, columns = _extract_records(blob)
result = self.chain(
inputs={"question": question, "choices": records, "columns": columns}
)
selected_records = result["selected"]
urls = [
# TODO(): handle absolute links
urllib.parse.urljoin(blob.source, record["link"])
for record in selected_records
if "mailto:" not in record["link"]
]
return urls
@classmethod
def from_default(
cls,
llm: BaseLanguageModel,
blob_parser: BaseBlobParser = MarkdownifyHTMLParser(),
) -> "ChainCrawler":
"""Create a crawler from the default LLM."""
chain = MultiSelectChain.from_default(llm)
return cls(chain=chain, parser=blob_parser)
class DocumentProcessor(abc.ABC):
def transform(self, document: Document) -> List[Document]:
"""Transform the document."""
raise NotImplementedError()
class ReadEntireDocChain(Chain):
"""Read entire document chain.
This chain implements a brute force approach to reading an entire document.
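
    A construction sketch (illustrative): any chain that accepts
    ``input_documents``/``question`` and returns ``output_text`` will work here.

    .. code-block:: python

        from langchain.chains.question_answering import load_qa_chain
        from langchain.text_splitter import RecursiveCharacterTextSplitter

        reader = ReadEntireDocChain(
            chain=load_qa_chain(llm, chain_type="map_reduce"),
            text_splitter=RecursiveCharacterTextSplitter(chunk_size=2000),
            max_num_docs=10,
        )
        result = reader({"doc": document, "question": "What is discussed?"})
        result["document"]  # a single summary Document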
"""
chain: Chain
text_splitter: TextSplitter
max_num_docs: int = -1
@property
def input_keys(self) -> List[str]:
"""Return the input keys."""
return ["doc", "question"]
@property
def output_keys(self) -> List[str]:
"""Return the output keys."""
return ["document"]
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Process a long document."""
source_document = inputs["doc"]
if not isinstance(source_document, Document):
raise TypeError(f"Expected a Document, got {type(source_document)}")
question = inputs["question"]
sub_docs = self.text_splitter.split_documents([source_document])
if self.max_num_docs > 0:
_sub_docs = sub_docs[: self.max_num_docs]
else:
_sub_docs = sub_docs
response = self.chain(
{"input_documents": _sub_docs, "question": question},
            callbacks=run_manager.get_child() if run_manager else None,
)
summary_doc = Document(
page_content=response["output_text"],
metadata=source_document.metadata,
)
return {"document": summary_doc}
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Process a long document."""
doc = inputs["doc"]
question = inputs["question"]
sub_docs = self.text_splitter.split_documents([doc])
if self.max_num_docs > 0:
_sub_docs = sub_docs[: self.max_num_docs]
else:
_sub_docs = sub_docs
results = await self.chain.acall(
{"input_documents": _sub_docs, "question": question},
            callbacks=run_manager.get_child() if run_manager else None,
)
summary_doc = Document(
page_content=results["output_text"],
metadata=doc.metadata,
)
return {"document": summary_doc}
class ParallelApply(Chain):
"""Apply a chain in parallel."""
chain: Chain
max_concurrency: int
@property
def input_keys(self) -> List[str]:
"""Return the input keys."""
return ["inputs"]
@property
def output_keys(self) -> List[str]:
"""Return the output keys."""
return ["results"]
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run the chain."""
# TODO(): parallelize this
chain_inputs = inputs["inputs"]
results = [
self.chain(
chain_input,
callbacks=run_manager.get_child() if run_manager else None,
)
for chain_input in chain_inputs
]
return {"results": results}
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run the chain."""
chain_inputs = inputs["inputs"]
results = await asyncio.gather(
*[
self.chain.acall(
chain_input,
callbacks=run_manager.get_child() if run_manager else None,
)
for chain_input in chain_inputs
]
)
return {"results": results}


@@ -0,0 +1,138 @@
"""Module contains code for fetching documents from the web using playwright.
This module currently re-uses the code from the `web_base` module to avoid
re-implementing rate limiting behavior.
The module defines downloading interfaces. Sub-classing the given interface
should allow a user to add URL-based user agents and authentication if needed.

Downloading is batched by default to allow efficient parallelization.
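
Example:
    A minimal async sketch (assumes the requests-based path is backed by a
    ``WebBaseLoader``):

    .. code-block:: python

        import asyncio

        from langchain.document_loaders import WebBaseLoader

        handler = AutoDownloadHandler(WebBaseLoader(web_path=[]))
        blobs = asyncio.run(handler.adownload(["https://example.com"]))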
"""
import abc
import asyncio
import mimetypes
from bs4 import BeautifulSoup
from typing import Sequence, List, Any
from langchain.document_loaders import WebBaseLoader
from langchain.document_loaders.blob_loaders import Blob
def _is_javascript_required(html_content: str) -> bool:
"""Heuristic to determine whether javascript execution is required.
Args:
html_content (str): The HTML content to check.
Returns:
bool: True if javascript execution is required, False otherwise.
"""
# Parse the HTML content using BeautifulSoup
soup = BeautifulSoup(html_content, "lxml")
# Count the number of HTML elements
body = soup.body
if not body:
return True
num_elements = len(body.find_all())
requires_javascript = num_elements < 1
return requires_javascript
class DownloadHandler(abc.ABC):
def download(self, urls: Sequence[str]) -> List[Blob]:
"""Download a url synchronously."""
raise NotImplementedError()
async def adownload(self, urls: Sequence[str]) -> List[Blob]:
"""Download a url asynchronously."""
raise NotImplementedError()
class PlaywrightDownloadHandler(DownloadHandler):
"""Download URLS using playwright.
This is an implementation of the download handler that uses playwright to download
urls. This is useful for downloading urls that require javascript to be executed.
"""
def download(self, urls: Sequence[str]) -> List[Blob]:
"""Download list of urls synchronously."""
# Implement using a threadpool or using playwright API if it supports it
raise NotImplementedError()
async def _download(self, browser: Any, url: str) -> str:
"""Download a url asynchronously using playwright."""
page = await browser.new_page()
await page.goto(url, wait_until="networkidle")
html_content = await page.content()
return html_content
    async def adownload(self, urls: Sequence[str]) -> List[Blob]:
        """Download a batch of urls asynchronously using playwright.

        Args:
            urls: the urls to download.

        Returns:
            a list of blobs containing the html content of each url.
        """
from playwright.async_api import async_playwright
async with async_playwright() as p:
browser = await p.chromium.launch()
tasks = [self._download(browser, url) for url in urls]
contents = await asyncio.gather(*tasks, return_exceptions=True)
await browser.close()
return _repackage_as_blobs(urls, contents)
class RequestsDownloadHandler(DownloadHandler):
def __init__(self, web_downloader: WebBaseLoader):
self.web_downloader = web_downloader
    def download(self, urls: Sequence[str]) -> List[Blob]:
        """Download a batch of urls synchronously."""
        # Implement with a threadpool.
        raise NotImplementedError()
    async def adownload(self, urls: Sequence[str]) -> List[Blob]:
        """Download a batch of urls asynchronously using the wrapped web loader."""
        contents = await self.web_downloader.fetch_all(list(urls))
        return _repackage_as_blobs(urls, contents)
def _repackage_as_blobs(urls: Sequence[str], contents: Sequence[str]) -> List[Blob]:
"""Repackage the contents as blobs."""
return [
        Blob(data=content, mimetype=mimetypes.guess_type(url)[0])
for url, content in zip(urls, contents)
]
class AutoDownloadHandler(DownloadHandler):
def __init__(self, web_downloader: WebBaseLoader) -> None:
"""Initialize the auto download handler."""
self.requests_downloader = RequestsDownloadHandler(web_downloader)
self.playwright_downloader = PlaywrightDownloadHandler()
    async def adownload(self, urls: Sequence[str]) -> List[Blob]:
        """Download a batch of urls, re-downloading with playwright when needed."""
        # First try the cheaper requests-based download.
        blobs = await self.requests_downloader.adownload(urls)
        # Identify pages that appear to require javascript to render.
        must_redownload = [
            (idx, url)
            for idx, (url, blob) in enumerate(zip(urls, blobs))
            if _is_javascript_required(blob.data)
        ]
        if not must_redownload:
            return blobs
        indexes, urls_to_redownload = zip(*must_redownload)
        # Re-download only the javascript-heavy pages with playwright.
        redownloaded = await self.playwright_downloader.adownload(
            list(urls_to_redownload)
        )
        for idx, new_blob in zip(indexes, redownloaded):
            blobs[idx] = new_blob
        return blobs


@@ -0,0 +1,218 @@
"""Perform classification / selection using language models."""
from __future__ import annotations
import csv
from bs4 import BeautifulSoup
from io import StringIO
from itertools import islice
from typing import Any, Dict, List, Mapping, Optional, Sequence, Set, TypedDict, cast
from langchain import LLMChain, PromptTemplate
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.schema import BaseOutputParser
MULTI_SELECT_TEMPLATE = """\
Here is a table in CSV format:
{records}
---
question:
{question}
---
Output the ids of the rows that answer or match the question.
For example, if row id 132 and id 133 are relevant, output: <ids>132,133</ids>
---
Begin:"""
def _extract_content_from_tag(html: str, tag: str) -> List[str]:
"""Extract content from the given tag."""
soup = BeautifulSoup(html, "html.parser")
queries = []
for query in soup.find_all(tag):
queries.append(query.text)
return queries
class IDParser(BaseOutputParser[List[int]]):
"""An output parser that extracts all IDs from the output."""
def parse(self, text: str) -> List[int]:
"""Parse the text and return a list of IDs"""
tags = _extract_content_from_tag(text, "ids")
if not tags:
return []
if len(tags) > 1:
# Fail if more than 1 tag group is identified
return []
tag = tags[0]
ids = tag.split(",")
finalized_ids = []
for idx in ids:
if idx.isdigit():
finalized_ids.append(int(idx))
return finalized_ids
def _write_records_to_string(
records: Sequence[Mapping[str, Any]],
*,
columns: Optional[Sequence[str]] = None,
delimiter: str = "|",
) -> str:
"""Write records to a CSV string.
Args:
records: a list of records, assumes that all records have all keys
columns: a list of columns to include in the CSV
delimiter: the delimiter to use
Returns:
a CSV string
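
    Example:
        .. code-block:: python

            _write_records_to_string(
                [{"id": 0, "name": "Ada"}, {"id": 1, "name": "Bob"}],
                columns=["id", "name"],
            )
            # header "id|name" followed by rows "0|Ada" and "1|Bob"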
"""
buffer = StringIO()
if columns is None:
existing_columns: Set[str] = set()
for record in records:
existing_columns.update(record.keys())
_columns: Sequence[str] = sorted(existing_columns)
else:
_columns = columns
    # Make sure the `id` column is always first
    _columns_with_id_first = list(_columns)
    if "id" in _columns_with_id_first:
        _columns_with_id_first.remove("id")
    _columns_with_id_first.insert(0, "id")
writer = csv.DictWriter(
buffer,
fieldnames=_columns_with_id_first,
delimiter=delimiter,
)
writer.writeheader()
writer.writerows(records)
buffer.seek(0)
return buffer.getvalue()
class MultiSelectionInput(TypedDict):
    """Input for the multi-selection chain."""

    question: str
    choices: Sequence[Mapping[str, Any]]
    columns: Optional[Sequence[str]]


class MultiSelectionOutput(TypedDict):
    """Output for the multi-selection chain."""

    selected: Sequence[Mapping[str, Any]]
def batch(iterable, size):
    """Yield successive batches of at most ``size`` items from the iterable."""
    iterator = iter(iterable)
    while True:
        chunk = list(islice(iterator, size))
        if not chunk:
            return
        yield chunk
class MultiSelectChain(Chain):
"""A chain that performs multi-selection from a list of choices."""
llm_chain: LLMChain
@property
def input_keys(self) -> List[str]:
"""Return the input keys."""
return ["question", "choices"]
@property
def output_keys(self) -> List[str]:
"""Return the output keys."""
return ["selected"]
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> MultiSelectionOutput:
"""Run the chain."""
choices = inputs["choices"]
question = inputs["question"]
columns = inputs.get("columns", None)
selected = []
max_choices = 30
for choice_batch in batch(choices, max_choices):
records_with_ids = [
{**record, "id": idx} for idx, record in enumerate(choice_batch)
]
records_str = _write_records_to_string(
records_with_ids, columns=columns, delimiter="|"
)
indexes = cast(
List[int],
self.llm_chain.predict_and_parse(
records=records_str,
question=question,
                    callbacks=run_manager.get_child() if run_manager else None,
),
)
valid_indexes = [idx for idx in indexes if 0 <= idx < len(choice_batch)]
selected.extend(choice_batch[i] for i in valid_indexes)
return {
"selected": selected,
}
@property
def _chain_type(self) -> str:
"""Return the chain type."""
return "multilabel_binary_classifier"
@classmethod
def from_default(
cls,
llm: BaseLanguageModel,
*,
prompt: str = MULTI_SELECT_TEMPLATE,
parser: BaseOutputParser = IDParser(),
) -> MultiSelectChain:
"""Provide a multilabel binary classifier."""
prompt_template = PromptTemplate.from_template(prompt, output_parser=parser)
if set(prompt_template.input_variables) != {"question", "records"}:
raise ValueError("Prompt must contain only {question} and {records}")
return cls(
llm_chain=LLMChain(
llm=llm,
prompt=prompt_template,
)
)