This commit is contained in:
Eugene Yurtsev
2023-06-02 14:08:22 -04:00
parent 2b42b9cb82
commit 307df3ebed
2 changed files with 39 additions and 23 deletions

View File

@@ -1,9 +1,10 @@
from __future__ import annotations
import asyncio
import itertools
from typing import Any, Dict, List, Literal, Mapping, Optional, Union
import itertools
from langchain.chains.research.download import AutoDownloadHandler, DownloadHandler
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
@@ -11,7 +12,6 @@ from langchain.callbacks.manager import (
)
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.research.fetch import AutoDownloadHandler, DownloadHandler
from langchain.chains.research.readers import DocReadingChain, ParallelApplyChain
from langchain.chains.research.search import GenericSearcher
from langchain.document_loaders.parsers.html.markdownify import MarkdownifyHTMLParser
@@ -109,7 +109,9 @@ class Research(Chain):
urls = search_results["urls"]
blobs = await self.downloader.adownload(urls)
parser = MarkdownifyHTMLParser()
docs = itertools.chain.from_iterable(parser.lazy_parse(blob) for blob in blobs)
docs = itertools.chain.from_iterable(
parser.lazy_parse(blob) for blob in blobs if blob is not None
)
inputs = [{"doc": doc, "question": question} for doc in docs]
results = await self.reader.acall(
inputs,

View File

@@ -14,14 +14,15 @@ Downloading is batched by default to allow efficient parallelization.
import abc
import asyncio
import mimetypes
from typing import Any, List, Optional, Sequence
from pydantic import ValidationError
from typing import Any, List, Sequence, Optional
from bs4 import BeautifulSoup
from langchain.document_loaders import WebBaseLoader
from langchain.document_loaders.blob_loaders import Blob
# A downloaded payload; None marks a URL whose download failed or timed out.
MaybeBlob = Optional[Blob]
def _is_javascript_required(html_content: str) -> bool:
"""Heuristic to determine whether javascript execution is required.
@@ -45,11 +46,11 @@ def _is_javascript_required(html_content: str) -> bool:
class DownloadHandler(abc.ABC):
    """Interface for downloading a batch of URLs.

    Implementations return one result per input URL, in order; an entry is
    None when that URL could not be downloaded.
    """

    def download(self, urls: Sequence[str]) -> List[MaybeBlob]:
        """Download a batch of URLs synchronously."""
        raise NotImplementedError()

    async def adownload(self, urls: Sequence[str]) -> List[MaybeBlob]:
        """Download a batch of URLs asynchronously."""
        raise NotImplementedError()
@@ -61,15 +62,29 @@ class PlaywrightDownloadHandler(DownloadHandler):
urls. This is useful for downloading urls that require javascript to be executed.
"""
def __init__(self, timeout: int = 5) -> None:
    """Create the handler.

    Args:
        timeout: Seconds to wait for a page to finish loading before
            giving up.
    """
    self.timeout = timeout
def download(self, urls: Sequence[str]) -> List[MaybeBlob]:
    """Download a list of urls synchronously.

    Wraps the async implementation in asyncio.run. Entries may be None
    for pages that failed to load within the timeout.
    """
    return asyncio.run(self.adownload(urls))
async def _download(self, browser: Any, url: str) -> Optional[str]:
    """Download a url asynchronously using playwright.

    Args:
        browser: An open playwright async browser instance.
        url: The url to fetch.

    Returns:
        The rendered page HTML, or None when the page did not finish
        loading within ``self.timeout`` seconds.
    """
    from playwright.async_api import TimeoutError

    page = await browser.new_page()
    try:
        # Playwright timeouts are in milliseconds; honor the handler's
        # configured timeout instead of playwright's 30s default.
        await page.goto(
            url, wait_until="networkidle", timeout=self.timeout * 1000
        )
        html_content = await page.content()
    except TimeoutError:
        html_content = None
    finally:
        # Close the page so repeated downloads don't leak pages in the
        # shared browser.
        await page.close()
    return html_content
async def adownload(self, urls: Sequence[str]) -> List[Blob]:
@@ -101,25 +116,24 @@ class RequestsDownloadHandler(DownloadHandler):
"""Download a batch of URLS synchronously."""
return asyncio.run(self.adownload(urls))
async def adownload(self, urls: Sequence[str]) -> List[MaybeBlob]:
    """Download a batch of urls asynchronously using WebBaseLoader.

    Returns one entry per url; entries may be None for failed fetches.
    """
    # WebBaseLoader requires a web_path at construction time, but only
    # its fetch_all helper is used here, so an empty placeholder is fine.
    loader = WebBaseLoader(web_path=[])
    contents = await loader.fetch_all(list(urls))
    return _repackage_as_blobs(urls, contents)
def _repackage_as_blobs(
    urls: Sequence[str], contents: Sequence[Optional[str]]
) -> List[MaybeBlob]:
    """Repackage downloaded contents as blobs, one per url.

    Args:
        urls: The urls that were fetched.
        contents: Downloaded payloads aligned with ``urls``; None marks a
            failed download.

    Returns:
        A list aligned with ``urls``: a Blob per successful download, or
        None where the download failed.
    """
    blobs: List[MaybeBlob] = []
    for url, content in zip(urls, contents):
        # Guess the mimetype from the url's extension; may be None.
        mimetype = mimetypes.guess_type(url)[0]
        if content is None:
            # Keep positional alignment with `urls` by recording the failure.
            blobs.append(None)
        else:
            blobs.append(Blob(data=content, mimetype=mimetype, path=url))
    return blobs
@@ -137,7 +151,7 @@ class AutoDownloadHandler(DownloadHandler):
)
self.playwright_downloader = PlaywrightDownloadHandler()
async def adownload(self, urls: Sequence[str]) -> List[Blob]:
async def adownload(self, urls: Sequence[str]) -> List[MaybeBlob]:
"""Download a batch of urls asynchronously using playwright."""
# Check if javascript is required
blobs = await self.requests_downloader.adownload(urls)
@@ -156,6 +170,6 @@ class AutoDownloadHandler(DownloadHandler):
blobs[idx] = blob
return blobs
def download(self, urls: Sequence[str]) -> List[MaybeBlob]:
    """Download a batch of URLs synchronously.

    Wraps the async implementation in asyncio.run. Entries may be None
    where a url could not be fetched.
    """
    return asyncio.run(self.adownload(urls))