community[minor]: Add alazy_load to AsyncHtmlLoader (#21536)

Also fixes a bug where `_scrape` was called on each fetched page and performed a second, synchronous HTTP request.
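
Below is a minimal usage sketch of the new async path (not part of the diff): it assumes `langchain_community` and `beautifulsoup4` are installed, and the URLs are placeholders. It exercises the `alazy_load` generator and the new keyword-only `preserve_order` option added in this PR.

```python
import asyncio

from langchain_community.document_loaders import AsyncHtmlLoader


async def main() -> None:
    # preserve_order=False yields documents as pages finish downloading,
    # instead of in the order of the input URLs (the default is True).
    loader = AsyncHtmlLoader(
        ["https://example.com", "https://example.org"],  # placeholder URLs
        preserve_order=False,
    )
    async for doc in loader.alazy_load():
        print(doc.metadata["source"], len(doc.page_content))


asyncio.run(main())
```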

**Twitter handle:** cbornet_
Author: Christophe Bornet (committed by GitHub)
Date: 2024-05-13 09:01:03 -07:00
Parent: 4c48732f94
Commit: e6fa4547b1


@@ -1,8 +1,18 @@
 import asyncio
 import logging
 import warnings
-from concurrent.futures import ThreadPoolExecutor
-from typing import Any, Dict, Iterator, List, Optional, Union, cast
+from concurrent.futures import Future, ThreadPoolExecutor
+from typing import (
+    Any,
+    AsyncIterator,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Tuple,
+    Union,
+    cast,
+)
 
 import aiohttp
 import requests
@@ -52,6 +62,8 @@ class AsyncHtmlLoader(BaseLoader):
         requests_kwargs: Optional[Dict[str, Any]] = None,
         raise_for_status: bool = False,
         ignore_load_errors: bool = False,
+        *,
+        preserve_order: bool = True,
     ):
         """Initialize with a webpage path."""
 
@@ -90,6 +102,7 @@ class AsyncHtmlLoader(BaseLoader):
         self.autoset_encoding = autoset_encoding
         self.encoding = encoding
         self.ignore_load_errors = ignore_load_errors
+        self.preserve_order = preserve_order
 
     def _fetch_valid_connection_docs(self, url: str) -> Any:
         if self.ignore_load_errors:
@@ -110,35 +123,6 @@ class AsyncHtmlLoader(BaseLoader):
                 "`parser` must be one of " + ", ".join(valid_parsers) + "."
             )
 
-    def _scrape(
-        self,
-        url: str,
-        parser: Union[str, None] = None,
-        bs_kwargs: Optional[dict] = None,
-    ) -> Any:
-        from bs4 import BeautifulSoup
-
-        if parser is None:
-            if url.endswith(".xml"):
-                parser = "xml"
-            else:
-                parser = self.default_parser
-
-        self._check_parser(parser)
-
-        html_doc = self._fetch_valid_connection_docs(url)
-        if not getattr(html_doc, "ok", False):
-            return None
-
-        if self.raise_for_status:
-            html_doc.raise_for_status()
-
-        if self.encoding is not None:
-            html_doc.encoding = self.encoding
-        elif self.autoset_encoding:
-            html_doc.encoding = html_doc.apparent_encoding
-
-        return BeautifulSoup(html_doc.text, parser, **(bs_kwargs or {}))
-
     async def _fetch(
         self, url: str, retries: int = 3, cooldown: int = 2, backoff: float = 1.5
     ) -> str:
@@ -172,51 +156,79 @@ class AsyncHtmlLoader(BaseLoader):
 
     async def _fetch_with_rate_limit(
         self, url: str, semaphore: asyncio.Semaphore
-    ) -> str:
+    ) -> Tuple[str, str]:
         async with semaphore:
-            return await self._fetch(url)
+            return url, await self._fetch(url)
 
-    async def fetch_all(self, urls: List[str]) -> Any:
-        """Fetch all urls concurrently with rate limiting."""
+    async def _lazy_fetch_all(
+        self, urls: List[str], preserve_order: bool
+    ) -> AsyncIterator[Tuple[str, str]]:
         semaphore = asyncio.Semaphore(self.requests_per_second)
-        tasks = []
-        for url in urls:
-            task = asyncio.ensure_future(self._fetch_with_rate_limit(url, semaphore))
-            tasks.append(task)
+        tasks = [
+            asyncio.create_task(self._fetch_with_rate_limit(url, semaphore))
+            for url in urls
+        ]
         try:
             from tqdm.asyncio import tqdm_asyncio
 
-            return await tqdm_asyncio.gather(
-                *tasks, desc="Fetching pages", ascii=True, mininterval=1
-            )
+            if preserve_order:
+                for task in tqdm_asyncio(
+                    tasks, desc="Fetching pages", ascii=True, mininterval=1
+                ):
+                    yield await task
+            else:
+                for task in tqdm_asyncio.as_completed(
+                    tasks, desc="Fetching pages", ascii=True, mininterval=1
+                ):
+                    yield await task
         except ImportError:
             warnings.warn("For better logging of progress, `pip install tqdm`")
-            return await asyncio.gather(*tasks)
+            if preserve_order:
+                for result in await asyncio.gather(*tasks):
+                    yield result
+            else:
+                for task in asyncio.as_completed(tasks):
+                    yield await task
+
+    async def fetch_all(self, urls: List[str]) -> List[str]:
+        """Fetch all urls concurrently with rate limiting."""
+        return [doc async for _, doc in self._lazy_fetch_all(urls, True)]
+
+    def _to_document(self, url: str, text: str) -> Document:
+        from bs4 import BeautifulSoup
+
+        if url.endswith(".xml"):
+            parser = "xml"
+        else:
+            parser = self.default_parser
+        self._check_parser(parser)
+        soup = BeautifulSoup(text, parser)
+        metadata = _build_metadata(soup, url)
+        return Document(page_content=text, metadata=metadata)
 
     def lazy_load(self) -> Iterator[Document]:
         """Lazy load text from the url(s) in web_path."""
-        for doc in self.load():
-            yield doc
-
-    def load(self) -> List[Document]:
-        """Load text from the url(s) in web_path."""
+        results: List[str]
         try:
             # Raises RuntimeError if there is no current event loop.
             asyncio.get_running_loop()
             # If there is a current event loop, we need to run the async code
             # in a separate loop, in a separate thread.
             with ThreadPoolExecutor(max_workers=1) as executor:
-                future = executor.submit(asyncio.run, self.fetch_all(self.web_paths))
+                future: Future[List[str]] = executor.submit(
+                    asyncio.run,  # type: ignore[arg-type]
+                    self.fetch_all(self.web_paths),  # type: ignore[arg-type]
+                )
                 results = future.result()
         except RuntimeError:
             results = asyncio.run(self.fetch_all(self.web_paths))
-        docs = []
-        for i, text in enumerate(cast(List[str], results)):
-            soup = self._scrape(self.web_paths[i])
-            if not soup:
-                continue
-            metadata = _build_metadata(soup, self.web_paths[i])
-            docs.append(Document(page_content=text, metadata=metadata))
-        return docs
+
+        for i, text in enumerate(cast(List[str], results)):
+            yield self._to_document(self.web_paths[i], text)
+
+    async def alazy_load(self) -> AsyncIterator[Document]:
+        """Lazy load text from the url(s) in web_path."""
+        async for url, text in self._lazy_fetch_all(
+            self.web_paths, self.preserve_order
+        ):
+            yield self._to_document(url, text)
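
For context on the `preserve_order` switch in `_lazy_fetch_all`, here is a minimal, self-contained sketch (not part of the diff) of the two iteration strategies it chooses between: awaiting `asyncio.gather` keeps results in input order, while `asyncio.as_completed` yields them as they finish.

```python
import asyncio
import random


async def fetch(url: str) -> str:
    # Simulated download with random latency (stand-in for the real HTTP call).
    await asyncio.sleep(random.random())
    return url


async def main() -> None:
    urls = ["a", "b", "c", "d"]

    # preserve_order=True path: results come back in input order.
    tasks = [asyncio.create_task(fetch(u)) for u in urls]
    print(await asyncio.gather(*tasks))

    # preserve_order=False path: results stream in completion order.
    tasks = [asyncio.create_task(fetch(u)) for u in urls]
    print([await t for t in asyncio.as_completed(tasks)])


asyncio.run(main())
```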