Compare commits


4 Commits

Author          SHA1        Message                                        Date
Bagatur         f2e51266cb  fmt                                            2023-07-24 16:39:36 -07:00
Bagatur         e656e8cb8b  merge                                          2023-07-24 16:38:57 -07:00
Harrison Chase  7c3ce368d7  Merge branch 'master' into harrison/async-web  2023-07-20 15:25:14 -07:00
Harrison Chase  56a321ba81  stash                                          2023-07-20 09:32:19 -07:00
2 changed files with 139 additions and 0 deletions

langchain/document_loaders/__init__.py

@@ -140,6 +140,7 @@ from langchain.document_loaders.url_playwright import PlaywrightURLLoader
from langchain.document_loaders.url_selenium import SeleniumURLLoader
from langchain.document_loaders.weather import WeatherDataLoader
from langchain.document_loaders.web_base import WebBaseLoader
from langchain.document_loaders.web_loader import AsyncRawWebLoader
from langchain.document_loaders.whatsapp_chat import WhatsAppChatLoader
from langchain.document_loaders.wikipedia import WikipediaLoader
from langchain.document_loaders.word_document import (
@@ -301,4 +302,5 @@ __all__ = [
"XorbitsLoader",
"YoutubeAudioLoader",
"YoutubeLoader",
"AsyncRawWebLoader",
]
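
With this export in place, the loader is importable from the package root:

    from langchain.document_loaders import AsyncRawWebLoader

A fuller usage sketch follows the new file below.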

langchain/document_loaders/web_loader.py

@@ -0,0 +1,137 @@
"""Web base loader class."""
import asyncio
import logging
import warnings
from typing import Any, Dict, Iterator, List, Optional, Union
import aiohttp
import requests
from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader
logger = logging.getLogger(__name__)
default_header_template = {
"User-Agent": "",
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*"
";q=0.8",
"Accept-Language": "en-US,en;q=0.5",
"Referer": "https://www.google.com/",
"DNT": "1",
"Connection": "keep-alive",
"Upgrade-Insecure-Requests": "1",
}
class AsyncRawWebLoader(BaseLoader):
"""Loader that loads all HTML asynchronously and returns the raw HTML."""
web_paths: List[str]
requests_per_second: int = 2
"""Max number of concurrent requests to make."""
requests_kwargs: Dict[str, Any] = {}
"""kwargs for requests"""
raise_for_status: bool = False
"""Raise an exception if http status code denotes an error."""
def __init__(
self,
web_path: Union[str, List[str]],
header_template: Optional[dict] = None,
verify_ssl: Optional[bool] = True,
proxies: Optional[dict] = None,
):
"""Initialize with webpage path."""
# TODO: Deprecate web_path in favor of web_paths, and remove this
# left like this because there are a number of loaders that expect single
# urls
if isinstance(web_path, str):
self.web_paths = [web_path]
elif isinstance(web_path, List):
self.web_paths = web_path
headers = header_template or default_header_template
if not headers.get("User-Agent"):
try:
from fake_useragent import UserAgent
headers["User-Agent"] = UserAgent().random
except ImportError:
logger.info(
"fake_useragent not found, using default user agent."
"To get a realistic header for requests, "
"`pip install fake_useragent`."
)
self.session = requests.Session()
self.session.headers = dict(headers)
self.session.verify = verify_ssl
if proxies:
self.session.proxies.update(proxies)
async def _fetch(
self, url: str, retries: int = 3, cooldown: int = 2, backoff: float = 1.5
) -> str:
async with aiohttp.ClientSession() as session:
for i in range(retries):
try:
async with session.get(
url,
headers=self.session.headers,
ssl=None if self.session.verify else False,
) as response:
return await response.text()
except aiohttp.ClientConnectionError as e:
if i == retries - 1:
raise
else:
logger.warning(
f"Error fetching {url} with attempt "
f"{i + 1}/{retries}: {e}. Retrying..."
)
await asyncio.sleep(cooldown * backoff**i)
raise ValueError("retry count exceeded")
async def _fetch_with_rate_limit(
self, url: str, semaphore: asyncio.Semaphore
) -> str:
async with semaphore:
return await self._fetch(url)
async def fetch_all(self, urls: List[str]) -> Any:
"""Fetch all urls concurrently with rate limiting."""
semaphore = asyncio.Semaphore(self.requests_per_second)
tasks = []
for url in urls:
task = asyncio.ensure_future(self._fetch_with_rate_limit(url, semaphore))
tasks.append(task)
try:
from tqdm.asyncio import tqdm_asyncio
return await tqdm_asyncio.gather(
*tasks, desc="Fetching pages", ascii=True, mininterval=1
)
except ImportError:
warnings.warn("For better logging of progress, `pip install tqdm`")
return await asyncio.gather(*tasks)
def lazy_load(self) -> Iterator[Document]:
"""Lazy load text from the url(s) in web_path."""
for doc in self.load():
yield doc
def load(self) -> List[Document]:
"""Load text from the url(s) in web_path."""
results = asyncio.run(self.fetch_all(self.web_paths))
docs = []
for i, text in enumerate(results):
metadata = {"source": self.web_paths[i]}
docs.append(Document(page_content=text, metadata=metadata))
return docs
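
A minimal usage sketch of the loader above; the URLs are illustrative placeholders, not from the diff. Note that load() calls asyncio.run(), so inside an already-running event loop (e.g. a notebook) fetch_all() should be awaited directly instead:

    from langchain.document_loaders import AsyncRawWebLoader

    # Placeholder URLs for illustration only.
    loader = AsyncRawWebLoader(
        ["https://example.com", "https://example.org"],
        verify_ssl=True,
    )
    docs = loader.load()  # fetches concurrently, at most 2 requests in flight
    for doc in docs:
        print(doc.metadata["source"], len(doc.page_content))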