refactor web base loader (#11057)

Authored by Bagatur on 2023-09-26 08:11:31 -07:00; committed by GitHub
parent 487611521d
commit 097ecef06b
5 changed files with 83 additions and 81 deletions

View File

@@ -32,13 +32,6 @@ class BlackboardLoader(WebBaseLoader):
     """  # noqa: E501
 
-    base_url: str
-    """Base url of the blackboard course."""
-    folder_path: str
-    """Path to the folder containing the documents."""
-    load_all_recursively: bool
-    """If True, load all documents recursively."""
-
     def __init__(
         self,
         blackboard_course_url: str,
@@ -46,7 +39,7 @@ class BlackboardLoader(WebBaseLoader):
         load_all_recursively: bool = True,
         basic_auth: Optional[Tuple[str, str]] = None,
         cookies: Optional[dict] = None,
-        continue_on_failure: Optional[bool] = False,
+        continue_on_failure: bool = False,
     ):
         """Initialize with blackboard course url.
@@ -66,7 +59,9 @@ class BlackboardLoader(WebBaseLoader):
         Raises:
             ValueError: If blackboard course url is invalid.
         """
-        super().__init__(blackboard_course_url)
+        super().__init__(
+            web_paths=(blackboard_course_url), continue_on_failure=continue_on_failure
+        )
         # Get base url
         try:
             self.base_url = blackboard_course_url.split("/webapps/blackboard")[0]
@@ -84,7 +79,6 @@ class BlackboardLoader(WebBaseLoader):
         cookies.update({"BbRouter": bbrouter})
         self.session.cookies.update(cookies)
         self.load_all_recursively = load_all_recursively
-        self.continue_on_failure = continue_on_failure
         self.check_bs4()
 
     def check_bs4(self) -> None:
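For reference, a minimal usage sketch of the refactored loader (the course URL and BbRouter value are placeholders, not taken from this commit): continue_on_failure now travels to WebBaseLoader through super().__init__ instead of being set on the instance afterwards.

from langchain.document_loaders import BlackboardLoader

loader = BlackboardLoader(
    blackboard_course_url="https://blackboard.example.com/webapps/blackboard/content/listContent.jsp?course_id=_123456_1",
    bbrouter="expires:12345...",  # placeholder session cookie value
    load_all_recursively=True,
    continue_on_failure=True,  # now forwarded to WebBaseLoader
)
docs = loader.load()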

View File

@@ -18,7 +18,7 @@ class GitbookLoader(WebBaseLoader):
         load_all_paths: bool = False,
         base_url: Optional[str] = None,
         content_selector: str = "main",
-        continue_on_failure: Optional[bool] = False,
+        continue_on_failure: bool = False,
     ):
         """Initialize with web page and whether to load all paths.
@@ -41,13 +41,10 @@ class GitbookLoader(WebBaseLoader):
             self.base_url = self.base_url[:-1]
         if load_all_paths:
             # set web_path to the sitemap if we want to crawl all paths
-            web_paths = f"{self.base_url}/sitemap.xml"
-        else:
-            web_paths = web_page
-        super().__init__(web_paths)
+            web_page = f"{self.base_url}/sitemap.xml"
+        super().__init__(web_paths=(web_page,), continue_on_failure=continue_on_failure)
         self.load_all_paths = load_all_paths
         self.content_selector = content_selector
-        self.continue_on_failure = continue_on_failure
 
     def load(self) -> List[Document]:
         """Fetch text from one single GitBook page."""

View File

@@ -33,6 +33,7 @@ class SitemapLoader(WebBaseLoader):
         meta_function: Optional[Callable] = None,
         is_local: bool = False,
         continue_on_failure: bool = False,
+        **kwargs: Any,
     ):
         """Initialize with webpage path and optional filter URLs.
@@ -67,7 +68,7 @@ class SitemapLoader(WebBaseLoader):
                 "lxml package not found, please install it with " "`pip install lxml`"
             )
 
-        super().__init__(web_path)
+        super().__init__(web_paths=[web_path], **kwargs)
 
         self.filter_urls = filter_urls
         self.parsing_function = parsing_function or _default_parsing_function
@@ -130,7 +131,7 @@ class SitemapLoader(WebBaseLoader):
             fp = open(self.web_path)
             soup = bs4.BeautifulSoup(fp, "xml")
         else:
-            soup = self.scrape("xml")
+            soup = self._scrape(self.web_path, parser="xml")
 
         els = self.parse_sitemap(soup)
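Since **kwargs is now forwarded to WebBaseLoader, base-loader options can be set directly on SitemapLoader. A sketch, assuming a hypothetical sitemap URL and filter:

from langchain.document_loaders import SitemapLoader

loader = SitemapLoader(
    "https://example.com/sitemap.xml",          # placeholder sitemap
    filter_urls=["https://example.com/docs/"],  # placeholder filter
    continue_on_failure=True,
    requests_per_second=2,    # forwarded to WebBaseLoader via **kwargs
    raise_for_status=False,   # forwarded to WebBaseLoader via **kwargs
)
docs = loader.load()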

View File

@@ -2,7 +2,7 @@
 import asyncio
 import logging
 import warnings
-from typing import Any, Dict, Iterator, List, Optional, Union
+from typing import Any, Dict, Iterator, List, Optional, Sequence, Union
 
 import aiohttp
 import requests
@@ -39,71 +39,77 @@ def _build_metadata(soup: Any, url: str) -> dict:
 class WebBaseLoader(BaseLoader):
     """Load HTML pages using `urllib` and parse them with `BeautifulSoup'."""
 
-    web_paths: List[str]
-
-    requests_per_second: int = 2
-    """Max number of concurrent requests to make."""
-
-    default_parser: str = "html.parser"
-    """Default parser to use for BeautifulSoup."""
-
-    requests_kwargs: Dict[str, Any] = {}
-    """kwargs for requests"""
-
-    raise_for_status: bool = False
-    """Raise an exception if http status code denotes an error."""
-
-    bs_get_text_kwargs: Dict[str, Any] = {}
-    """kwargs for beatifulsoup4 get_text"""
-
     def __init__(
         self,
-        web_path: Union[str, List[str]],
+        web_path: Union[str, Sequence[str]] = "",
         header_template: Optional[dict] = None,
-        verify_ssl: Optional[bool] = True,
+        verify_ssl: bool = True,
         proxies: Optional[dict] = None,
-        continue_on_failure: Optional[bool] = False,
-        autoset_encoding: Optional[bool] = True,
+        continue_on_failure: bool = False,
+        autoset_encoding: bool = True,
         encoding: Optional[str] = None,
-    ):
-        """Initialize with webpage path."""
-
-        # TODO: Deprecate web_path in favor of web_paths, and remove this
-        # left like this because there are a number of loaders that expect single
-        # urls
-        if isinstance(web_path, str):
-            self.web_paths = [web_path]
-        elif isinstance(web_path, List):
-            self.web_paths = web_path
-
-        try:
-            import bs4  # noqa:F401
-        except ImportError:
-            raise ImportError(
-                "bs4 package not found, please install it with " "`pip install bs4`"
-            )
-
-        headers = header_template or default_header_template
-        if not headers.get("User-Agent"):
-            try:
-                from fake_useragent import UserAgent
-
-                headers["User-Agent"] = UserAgent().random
-            except ImportError:
-                logger.info(
-                    "fake_useragent not found, using default user agent."
-                    "To get a realistic header for requests, "
-                    "`pip install fake_useragent`."
-                )
-
-        self.session = requests.Session()
-        self.session.headers = dict(headers)
-        self.session.verify = verify_ssl
+        web_paths: Sequence[str] = (),
+        requests_per_second: int = 2,
+        default_parser: str = "html.parser",
+        requests_kwargs: Optional[Dict[str, Any]] = None,
+        raise_for_status: bool = False,
+        bs_get_text_kwargs: Optional[Dict[str, Any]] = None,
+        bs_kwargs: Optional[Dict[str, Any]] = None,
+        session: Any = None,
+    ) -> None:
+        """Initialize loader.
+
+        Args:
+            web_paths: Web paths to load from.
+            requests_per_second: Max number of concurrent requests to make.
+            default_parser: Default parser to use for BeautifulSoup.
+            requests_kwargs: kwargs for requests
+            raise_for_status: Raise an exception if http status code denotes an error.
+            bs_get_text_kwargs: kwargs for beatifulsoup4 get_text
+            bs_kwargs: kwargs for beatifulsoup4 web page parsing
+        """
+        # web_path kept for backwards-compatibility.
+        if web_path and web_paths:
+            raise ValueError(
+                "Received web_path and web_paths. Only one can be specified. "
+                "web_path is deprecated, web_paths should be used."
+            )
+        if web_paths:
+            self.web_paths = list(web_paths)
+        elif isinstance(web_path, Sequence):
+            self.web_paths = list(web_path)
+        else:
+            self.web_paths = [web_path]
+        self.requests_per_second = requests_per_second
+        self.default_parser = default_parser
+        self.requests_kwargs = requests_kwargs or {}
+        self.raise_for_status = raise_for_status
+        self.bs_get_text_kwargs = bs_get_text_kwargs or {}
+        self.bs_kwargs = bs_kwargs or {}
+        if session:
+            self.session = session
+        else:
+            session = requests.Session()
+            header_template = header_template or default_header_template.copy()
+            if not header_template.get("User-Agent"):
+                try:
+                    from fake_useragent import UserAgent
+
+                    header_template["User-Agent"] = UserAgent().random
+                except ImportError:
+                    logger.info(
+                        "fake_useragent not found, using default user agent."
+                        "To get a realistic header for requests, "
+                        "`pip install fake_useragent`."
+                    )
+            session.headers = dict(header_template)
+            session.verify = verify_ssl
+            if proxies:
+                session.proxies.update(proxies)
+            self.session = session
         self.continue_on_failure = continue_on_failure
         self.autoset_encoding = autoset_encoding
         self.encoding = encoding
-        if proxies:
-            self.session.proxies.update(proxies)
 
     @property
     def web_path(self) -> str:
@@ -193,11 +199,16 @@ class WebBaseLoader(BaseLoader):
             else:
                 parser = self.default_parser
             self._check_parser(parser)
-            final_results.append(BeautifulSoup(result, parser))
+            final_results.append(BeautifulSoup(result, parser, **self.bs_kwargs))
         return final_results
 
-    def _scrape(self, url: str, parser: Union[str, None] = None) -> Any:
+    def _scrape(
+        self,
+        url: str,
+        parser: Union[str, None] = None,
+        bs_kwargs: Optional[dict] = None,
+    ) -> Any:
         from bs4 import BeautifulSoup
 
         if parser is None:
@@ -216,7 +227,7 @@ class WebBaseLoader(BaseLoader):
             html_doc.encoding = self.encoding
         elif self.autoset_encoding:
             html_doc.encoding = html_doc.apparent_encoding
-        return BeautifulSoup(html_doc.text, parser)
+        return BeautifulSoup(html_doc.text, parser, **(bs_kwargs or {}))
 
     def scrape(self, parser: Union[str, None] = None) -> Any:
         """Scrape data from webpage and return it in BeautifulSoup format."""
@@ -224,7 +235,7 @@ class WebBaseLoader(BaseLoader):
         if parser is None:
             parser = self.default_parser
 
-        return self._scrape(self.web_path, parser)
+        return self._scrape(self.web_path, parser=parser, bs_kwargs=self.bs_kwargs)
 
     def lazy_load(self) -> Iterator[Document]:
         """Lazy load text from the url(s) in web_path."""
@@ -243,10 +254,9 @@ class WebBaseLoader(BaseLoader):
         results = self.scrape_all(self.web_paths)
         docs = []
-        for i in range(len(results)):
-            soup = results[i]
+        for path, soup in zip(self.web_paths, results):
             text = soup.get_text(**self.bs_get_text_kwargs)
-            metadata = _build_metadata(soup, self.web_paths[i])
+            metadata = _build_metadata(soup, path)
             docs.append(Document(page_content=text, metadata=metadata))
 
         return docs
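To summarize the refactor, here is a sketch of the new constructor: the former class attributes (requests_per_second, default_parser, requests_kwargs, raise_for_status, bs_get_text_kwargs) are now __init__ arguments, bs_kwargs is forwarded to BeautifulSoup when pages are parsed, and a pre-configured requests.Session can be injected. The URLs, header value, and SoupStrainer filter below are illustrative placeholders, not taken from this commit.

import requests
from bs4 import SoupStrainer

from langchain.document_loaders import WebBaseLoader

session = requests.Session()
session.headers.update({"User-Agent": "my-crawler/0.1"})  # placeholder header

loader = WebBaseLoader(
    web_paths=["https://example.com/a", "https://example.com/b"],
    requests_per_second=1,
    requests_kwargs={"timeout": 10},
    raise_for_status=True,
    bs_kwargs={"parse_only": SoupStrainer("article")},     # passed to BeautifulSoup
    bs_get_text_kwargs={"separator": " ", "strip": True},  # passed to soup.get_text
    session=session,  # used as-is; header_template/verify_ssl/proxies are then skipped
)
docs = loader.load()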

View File

@@ -50,7 +50,7 @@
             return f"No news found for company that searched with {query} ticker."
         if not links:
             return f"No news found for company that searched with {query} ticker."
-        loader = WebBaseLoader(links)
+        loader = WebBaseLoader(web_paths=links)
         docs = loader.load()
         result = self._format_results(docs, query)
         if not result:
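A quick usage sketch of the tool with the updated call (requires the yfinance package; the import path and ticker are assumptions for illustration):

from langchain.tools import YahooFinanceNewsTool  # assumes the tool is exported here

tool = YahooFinanceNewsTool()
# Fetches news links for the ticker and loads them via WebBaseLoader(web_paths=...)
print(tool.run("AAPL"))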