Mirror of https://github.com/hwchase17/langchain.git

fix recursive loader (#10752)

Maintain the same base url throughout recursion, yield the initial page, and fix recursion depth tracking.

parent 276125a33b
commit 96a9c27116
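Usage sketch (not part of the diff): the snippet below assumes the post-fix constructor shown in the changes that follow; the docs.python.org URL and the BeautifulSoup-based extractor are illustrative choices only.

import re

from bs4 import BeautifulSoup
from langchain.document_loaders.recursive_url_loader import RecursiveUrlLoader


def extract_text(raw_html: str) -> str:
    # Strip tags so page_content holds readable text; returning "" drops the page.
    return re.sub(r"\s+", " ", BeautifulSoup(raw_html, "html.parser").text).strip()


loader = RecursiveUrlLoader(
    "https://docs.python.org/3.9/",  # crawling stays rooted at this base url
    max_depth=2,                     # root page plus the pages it links to
    extractor=extract_text,
    prevent_outside=True,            # default: skip links outside the base url
    timeout=10,                      # default: per-request timeout in seconds
)

# The start page itself is now yielded as a Document too.
docs = loader.load()
print(len(docs), docs[0].metadata.get("source"))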
libs/langchain/langchain/document_loaders/recursive_url_loader.py
@@ -1,12 +1,51 @@
+from __future__ import annotations
+
 import asyncio
+import logging
 import re
-from typing import Callable, Iterator, List, Optional, Set, Union
-from urllib.parse import urljoin, urlparse
+from typing import (
+    TYPE_CHECKING,
+    Callable,
+    Iterator,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Union,
+)
 
 import requests
 
 from langchain.docstore.document import Document
 from langchain.document_loaders.base import BaseLoader
+from langchain.utils.html import extract_sub_links
+
+if TYPE_CHECKING:
+    import aiohttp
+
+logger = logging.getLogger(__name__)
+
+
+def _metadata_extractor(raw_html: str, url: str) -> dict:
+    """Extract metadata from raw html using BeautifulSoup."""
+    metadata = {"source": url}
+
+    try:
+        from bs4 import BeautifulSoup
+    except ImportError:
+        logger.warning(
+            "The bs4 package is required for default metadata extraction. "
+            "Please install it with `pip install bs4`."
+        )
+        return metadata
+    soup = BeautifulSoup(raw_html, "html.parser")
+    if title := soup.find("title"):
+        metadata["title"] = title.get_text()
+    if description := soup.find("meta", attrs={"name": "description"}):
+        metadata["description"] = description.get("content", None)
+    if html := soup.find("html"):
+        metadata["language"] = html.get("lang", None)
+    return metadata
 
 
 class RecursiveUrlLoader(BaseLoader):
@@ -15,173 +54,106 @@ class RecursiveUrlLoader(BaseLoader):
     def __init__(
         self,
         url: str,
-        max_depth: Optional[int] = None,
+        max_depth: Optional[int] = 2,
         use_async: Optional[bool] = None,
         extractor: Optional[Callable[[str], str]] = None,
-        exclude_dirs: Optional[str] = None,
-        timeout: Optional[int] = None,
-        prevent_outside: Optional[bool] = None,
+        metadata_extractor: Optional[Callable[[str, str], str]] = None,
+        exclude_dirs: Optional[Sequence[str]] = (),
+        timeout: Optional[int] = 10,
+        prevent_outside: Optional[bool] = True,
+        link_regex: Union[str, re.Pattern, None] = None,
+        headers: Optional[dict] = None,
     ) -> None:
         """Initialize with URL to crawl and any subdirectories to exclude.
         Args:
             url: The URL to crawl.
-            exclude_dirs: A list of subdirectories to exclude.
-            use_async: Whether to use asynchronous loading,
-                if use_async is true, this function will not be lazy,
-                but it will still work in the expected way, just not lazy.
-            extractor: A function to extract the text from the html,
-                when extract function returns empty string, the document will be ignored.
             max_depth: The max depth of the recursive loading.
-            timeout: The timeout for the requests, in the unit of seconds.
+            use_async: Whether to use asynchronous loading.
+                If True, this function will not be lazy, but it will still work in the
+                expected way, just not lazy.
+            extractor: A function to extract document contents from raw html.
+                When extract function returns an empty string, the document is
+                ignored.
+            metadata_extractor: A function to extract metadata from raw html and the
+                source url (args in that order). Default extractor will attempt
+                to use BeautifulSoup4 to extract the title, description and language
+                of the page.
+            exclude_dirs: A list of subdirectories to exclude.
+            timeout: The timeout for the requests, in the unit of seconds. If None then
+                connection will not timeout.
+            prevent_outside: If True, prevent loading from urls which are not children
+                of the root url.
+            link_regex: Regex for extracting sub-links from the raw html of a web page.
         """
 
         self.url = url
-        self.exclude_dirs = exclude_dirs
+        self.max_depth = max_depth if max_depth is not None else 2
         self.use_async = use_async if use_async is not None else False
         self.extractor = extractor if extractor is not None else lambda x: x
-        self.max_depth = max_depth if max_depth is not None else 2
-        self.timeout = timeout if timeout is not None else 10
+        self.metadata_extractor = (
+            metadata_extractor
+            if metadata_extractor is not None
+            else _metadata_extractor
+        )
+        self.exclude_dirs = exclude_dirs if exclude_dirs is not None else ()
+        self.timeout = timeout
         self.prevent_outside = prevent_outside if prevent_outside is not None else True
-
-    def _get_sub_links(self, raw_html: str, base_url: str) -> List[str]:
-        """This function extracts all the links from the raw html,
-        and convert them into absolute paths.
-
-        Args:
-            raw_html (str): original html
-            base_url (str): the base url of the html
-
-        Returns:
-            List[str]: sub links
-        """
-        # Get all links that are relative to the root of the website
-        all_links = re.findall(r"href=[\"\'](.*?)[\"\']", raw_html)
-        absolute_paths = []
-        invalid_prefixes = ("javascript:", "mailto:", "#")
-        invalid_suffixes = (
-            ".css",
-            ".js",
-            ".ico",
-            ".png",
-            ".jpg",
-            ".jpeg",
-            ".gif",
-            ".svg",
-        )
-        # Process the links
-        for link in all_links:
-            # Ignore blacklisted patterns
-            # like javascript: or mailto:, files of svg, ico, css, js
-            if link.startswith(invalid_prefixes) or link.endswith(invalid_suffixes):
-                continue
-            # Some may be absolute links like https://to/path
-            if link.startswith("http"):
-                if (not self.prevent_outside) or (
-                    self.prevent_outside and link.startswith(base_url)
-                ):
-                    absolute_paths.append(link)
-                else:
-                    absolute_paths.append(urljoin(base_url, link))
-
-            # Some may be relative links like /to/path
-            if link.startswith("/") and not link.startswith("//"):
-                absolute_paths.append(urljoin(base_url, link))
-                continue
-            # Some may have omitted the protocol like //to/path
-            if link.startswith("//"):
-                absolute_paths.append(f"{urlparse(base_url).scheme}:{link}")
-                continue
-        # Remove duplicates
-        # also do another filter to prevent outside links
-        absolute_paths = list(
-            set(
-                [
-                    path
-                    for path in absolute_paths
-                    if not self.prevent_outside
-                    or path.startswith(base_url)
-                    and path != base_url
-                ]
-            )
-        )
-
-        return absolute_paths
-
-    def _gen_metadata(self, raw_html: str, url: str) -> dict:
-        """Build metadata from BeautifulSoup output."""
-        try:
-            from bs4 import BeautifulSoup
-        except ImportError:
-            print("The bs4 package is required for the RecursiveUrlLoader.")
-            print("Please install it with `pip install bs4`.")
-        metadata = {"source": url}
-        soup = BeautifulSoup(raw_html, "html.parser")
-        if title := soup.find("title"):
-            metadata["title"] = title.get_text()
-        if description := soup.find("meta", attrs={"name": "description"}):
-            metadata["description"] = description.get("content", None)
-        if html := soup.find("html"):
-            metadata["language"] = html.get("lang", None)
-        return metadata
+        self.link_regex = link_regex
+        self._lock = asyncio.Lock() if self.use_async else None
+        self.headers = headers
 
     def _get_child_links_recursive(
-        self, url: str, visited: Optional[Set[str]] = None, depth: int = 0
+        self, url: str, visited: Set[str], *, depth: int = 0
     ) -> Iterator[Document]:
         """Recursively get all child links starting with the path of the input URL.
 
         Args:
             url: The URL to crawl.
             visited: A set of visited URLs.
+            depth: Current depth of recursion. Stop when depth >= max_depth.
         """
 
-        if depth > self.max_depth:
-            return []
-
-        # Add a trailing slash if not present
-        if not url.endswith("/"):
-            url += "/"
-
-        # Exclude the root and parent from a list
-        visited = set() if visited is None else visited
-
+        if depth >= self.max_depth:
+            return
         # Exclude the links that start with any of the excluded directories
-        if self.exclude_dirs and any(
-            url.startswith(exclude_dir) for exclude_dir in self.exclude_dirs
-        ):
-            return []
+        if any(url.startswith(exclude_dir) for exclude_dir in self.exclude_dirs):
+            return
 
         # Get all links that can be accessed from the current URL
         try:
-            response = requests.get(url, timeout=self.timeout)
+            response = requests.get(url, timeout=self.timeout, headers=self.headers)
         except Exception:
-            return []
-
-        absolute_paths = self._get_sub_links(response.text, url)
+            logger.warning(f"Unable to load from {url}")
+            return
+        content = self.extractor(response.text)
+        if content:
+            yield Document(
+                page_content=content,
+                metadata=self.metadata_extractor(response.text, url),
+            )
+        visited.add(url)
 
         # Store the visited links and recursively visit the children
-        for link in absolute_paths:
+        sub_links = extract_sub_links(
+            response.text,
+            self.url,
+            pattern=self.link_regex,
+            prevent_outside=self.prevent_outside,
+        )
+        for link in sub_links:
             # Check all unvisited links
             if link not in visited:
                 visited.add(link)
-
-                try:
-                    response = requests.get(link)
-                    text = response.text
-                except Exception:
-                    # unreachable link, so just ignore it
-                    continue
-                loaded_link = Document(
-                    page_content=self.extractor(text),
-                    metadata=self._gen_metadata(text, link),
+                yield from self._get_child_links_recursive(
+                    link, visited, depth=depth + 1
                 )
-                yield loaded_link
-                # If the link is a directory (w/ children) then visit it
-                if link.endswith("/"):
-                    yield from self._get_child_links_recursive(link, visited, depth + 1)
-        return []
 
     async def _async_get_child_links_recursive(
-        self, url: str, visited: Optional[Set[str]] = None, depth: int = 0
+        self,
+        url: str,
+        visited: Set[str],
+        *,
+        session: Optional[aiohttp.ClientSession] = None,
+        depth: int = 0,
     ) -> List[Document]:
         """Recursively get all child links starting with the path of the input URL.
@@ -193,117 +165,87 @@ class RecursiveUrlLoader(BaseLoader):
         try:
             import aiohttp
         except ImportError:
-            print("The aiohttp package is required for the RecursiveUrlLoader.")
-            print("Please install it with `pip install aiohttp`.")
-        if depth > self.max_depth:
+            raise ImportError(
+                "The aiohttp package is required for the RecursiveUrlLoader. "
+                "Please install it with `pip install aiohttp`."
+            )
+        if depth >= self.max_depth:
             return []
 
-        # Add a trailing slash if not present
-        if not url.endswith("/"):
-            url += "/"
-
-        # Exclude the root and parent from a list
-        visited = set() if visited is None else visited
-
         # Exclude the links that start with any of the excluded directories
-        if self.exclude_dirs and any(
-            url.startswith(exclude_dir) for exclude_dir in self.exclude_dirs
-        ):
+        if any(url.startswith(exclude_dir) for exclude_dir in self.exclude_dirs):
             return []
         # Disable SSL verification because websites may have invalid SSL certificates,
         # but won't cause any security issues for us.
-        async with aiohttp.ClientSession(
+        close_session = session is None
+        session = session or aiohttp.ClientSession(
             connector=aiohttp.TCPConnector(ssl=False),
-            timeout=aiohttp.ClientTimeout(self.timeout),
-        ) as session:
-            # Some url may be invalid, so catch the exception
-            response: aiohttp.ClientResponse
-            try:
-                response = await session.get(url)
+            timeout=aiohttp.ClientTimeout(total=self.timeout),
+            headers=self.headers,
+        )
+        try:
+            async with session.get(url) as response:
                 text = await response.text()
-            except aiohttp.client_exceptions.InvalidURL:
-                return []
-            # There may be some other exceptions, so catch them,
-            # we don't want to stop the whole process
-            except Exception:
-                return []
-
-            absolute_paths = self._get_sub_links(text, url)
-
-            # Worker will be only called within the current function
-            # Worker function will process the link
-            # then recursively call get_child_links_recursive to process the children
-            async def worker(link: str) -> Union[Document, None]:
-                try:
-                    async with aiohttp.ClientSession(
-                        connector=aiohttp.TCPConnector(ssl=False),
-                        timeout=aiohttp.ClientTimeout(self.timeout),
-                    ) as session:
-                        response = await session.get(link)
-                        text = await response.text()
-                        extracted = self.extractor(text)
-                        if len(extracted) > 0:
-                            return Document(
-                                page_content=extracted,
-                                metadata=self._gen_metadata(text, link),
-                            )
-                        else:
-                            return None
-                # Despite the fact that we have filtered some links,
-                # there may still be some invalid links, so catch the exception
-                except aiohttp.client_exceptions.InvalidURL:
-                    return None
-                # There may be some other exceptions, so catch them,
-                # we don't want to stop the whole process
-                except Exception:
-                    # print(e)
-                    return None
-
-            # The coroutines that will be executed
-            tasks = []
-            # Generate the tasks
-            for link in absolute_paths:
-                # Check all unvisited links
-                if link not in visited:
-                    visited.add(link)
-                    tasks.append(worker(link))
-            # Get the not None results
-            results = list(
-                filter(lambda x: x is not None, await asyncio.gather(*tasks))
+            async with self._lock:  # type: ignore
+                visited.add(url)
+        except (aiohttp.client_exceptions.InvalidURL, Exception) as e:
+            logger.warning(
+                f"Unable to load {url}. Received error {e} of type "
+                f"{e.__class__.__name__}"
             )
+            return []
+        results = []
+        content = self.extractor(text)
+        if content:
+            results.append(
+                Document(
+                    page_content=content,
+                    metadata=self.metadata_extractor(text, url),
+                )
+            )
+        if depth < self.max_depth - 1:
+            sub_links = extract_sub_links(
+                text,
+                self.url,
+                pattern=self.link_regex,
+                prevent_outside=self.prevent_outside,
+            )
+
             # Recursively call the function to get the children of the children
             sub_tasks = []
-            for link in absolute_paths:
-                sub_tasks.append(
-                    self._async_get_child_links_recursive(link, visited, depth + 1)
-                )
-            # sub_tasks returns coroutines of list,
-            # so we need to flatten the list await asyncio.gather(*sub_tasks)
-            flattened = []
+            async with self._lock:  # type: ignore
+                to_visit = set(sub_links).difference(visited)
+                for link in to_visit:
+                    sub_tasks.append(
+                        self._async_get_child_links_recursive(
+                            link, visited, session=session, depth=depth + 1
+                        )
+                    )
             next_results = await asyncio.gather(*sub_tasks)
             for sub_result in next_results:
-                if isinstance(sub_result, Exception):
+                if isinstance(sub_result, Exception) or sub_result is None:
                     # We don't want to stop the whole process, so just ignore it
-                    # Not standard html format or invalid url or 404 may cause this
-                    # But we can't do anything about it.
+                    # Not standard html format or invalid url or 404 may cause this.
                     continue
-                if sub_result is not None:
-                    flattened += sub_result
-            results += flattened
-            return list(filter(lambda x: x is not None, results))
+                # locking not fully working, temporary hack to ensure deduplication
+                results += [r for r in sub_result if r not in results]
+        if close_session:
+            await session.close()
+        return results
 
     def lazy_load(self) -> Iterator[Document]:
         """Lazy load web pages.
         When use_async is True, this function will not be lazy,
         but it will still work in the expected way, just not lazy."""
+        visited: Set[str] = set()
         if self.use_async:
-            results = asyncio.run(self._async_get_child_links_recursive(self.url))
-            if results is None:
-                return iter([])
-            else:
-                return iter(results)
+            results = asyncio.run(
+                self._async_get_child_links_recursive(self.url, visited)
+            )
+            return iter(results or [])
         else:
-            return self._get_child_links_recursive(self.url)
+            return self._get_child_links_recursive(self.url, visited)
 
     def load(self) -> List[Document]:
         """Load web pages."""
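For reference, a hypothetical sketch (not from this commit) of supplying a custom metadata_extractor using the (raw_html, url) argument order documented in the new docstring; the helper name and its title-only logic are invented for illustration.

from typing import Any, Dict

from langchain.document_loaders.recursive_url_loader import RecursiveUrlLoader


def title_only_metadata(raw_html: str, url: str) -> Dict[str, Any]:
    # Mirror the default extractor's shape: always include the source url.
    metadata: Dict[str, Any] = {"source": url}
    start, end = raw_html.find("<title>"), raw_html.find("</title>")
    if start != -1 and end != -1:
        metadata["title"] = raw_html[start + len("<title>") : end].strip()
    return metadata


loader = RecursiveUrlLoader(
    "https://docs.python.org/3.9/",
    max_depth=2,
    metadata_extractor=title_only_metadata,
)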
libs/langchain/langchain/utils/html.py (new file, 69 lines)
@@ -0,0 +1,69 @@
+import re
+from typing import List, Union
+from urllib.parse import urljoin, urlparse
+
+PREFIXES_TO_IGNORE = ("javascript:", "mailto:", "#")
+SUFFIXES_TO_IGNORE = (
+    ".css",
+    ".js",
+    ".ico",
+    ".png",
+    ".jpg",
+    ".jpeg",
+    ".gif",
+    ".svg",
+    ".csv",
+    ".bz2",
+    ".zip",
+    ".epub",
+)
+SUFFIXES_TO_IGNORE_REGEX = (
+    "(?!" + "|".join([re.escape(s) + "[\#'\"]" for s in SUFFIXES_TO_IGNORE]) + ")"
+)
+PREFIXES_TO_IGNORE_REGEX = (
+    "(?!" + "|".join([re.escape(s) for s in PREFIXES_TO_IGNORE]) + ")"
+)
+DEFAULT_LINK_REGEX = (
+    f"href=[\"']{PREFIXES_TO_IGNORE_REGEX}((?:{SUFFIXES_TO_IGNORE_REGEX}.)*?)[\#'\"]"
+)
+
+
+def find_all_links(
+    raw_html: str, *, pattern: Union[str, re.Pattern, None] = None
+) -> List[str]:
+    pattern = pattern or DEFAULT_LINK_REGEX
+    return list(set(re.findall(pattern, raw_html)))
+
+
+def extract_sub_links(
+    raw_html: str,
+    base_url: str,
+    *,
+    pattern: Union[str, re.Pattern, None] = None,
+    prevent_outside: bool = True,
+) -> List[str]:
+    """Extract all links from a raw html string and convert into absolute paths.
+
+    Args:
+        raw_html: original html
+        base_url: the base url of the html
+        pattern: Regex to use for extracting links from raw html.
+        prevent_outside: If True, ignore external links which are not children
+            of the base url.
+
+    Returns:
+        List[str]: sub links
+    """
+    all_links = find_all_links(raw_html, pattern=pattern)
+    absolute_paths = set()
+    for link in all_links:
+        # Some may be absolute links like https://to/path
+        if link.startswith("http"):
+            if not prevent_outside or link.startswith(base_url):
+                absolute_paths.add(link)
+        # Some may have omitted the protocol like //to/path
+        elif link.startswith("//"):
+            absolute_paths.add(f"{urlparse(base_url).scheme}:{link}")
+        else:
+            absolute_paths.add(urljoin(base_url, link))
+    return list(absolute_paths)
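Quick illustration of the new helpers (not part of the diff); the sample html and the expected extract_sub_links output mirror the unit tests added below.

from langchain.utils.html import extract_sub_links, find_all_links

html = (
    '<a href="https://foobar.com">one</a>'
    '<a href="http://baz.net">two</a>'
    '<a href="//foobar.com/hello">three</a>'
    '<a href="/how/are/you/doing">four</a>'
)

# Every href, minus ignored prefixes/suffixes and any trailing #fragment.
print(sorted(find_all_links(html)))

# With prevent_outside=True (the default) only links under the base url survive,
# and protocol-relative / relative links are resolved against it.
print(sorted(extract_sub_links(html, "https://foobar.com")))
# ['https://foobar.com', 'https://foobar.com/hello', 'https://foobar.com/how/are/you/doing']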
libs/langchain/tests/integration_tests/document_loaders/test_recursive_url_loader.py
@@ -1,30 +1,61 @@
 import pytest as pytest
 
 from langchain.document_loaders.recursive_url_loader import RecursiveUrlLoader
 
 
 @pytest.mark.asyncio
 def test_async_recursive_url_loader() -> None:
     url = "https://docs.python.org/3.9/"
     loader = RecursiveUrlLoader(
-        url=url, extractor=lambda _: "placeholder", use_async=True, max_depth=1
+        url,
+        extractor=lambda _: "placeholder",
+        use_async=True,
+        max_depth=3,
+        timeout=None,
     )
     docs = loader.load()
-    assert len(docs) == 24
+    assert len(docs) == 1024
     assert docs[0].page_content == "placeholder"
 
 
+@pytest.mark.asyncio
+def test_async_recursive_url_loader_deterministic() -> None:
+    url = "https://docs.python.org/3.9/"
+    loader = RecursiveUrlLoader(
+        url,
+        use_async=True,
+        max_depth=3,
+        timeout=None,
+    )
+    docs = sorted(loader.load(), key=lambda d: d.metadata["source"])
+    docs_2 = sorted(loader.load(), key=lambda d: d.metadata["source"])
+    assert docs == docs_2
+
+
 def test_sync_recursive_url_loader() -> None:
     url = "https://docs.python.org/3.9/"
     loader = RecursiveUrlLoader(
-        url=url, extractor=lambda _: "placeholder", use_async=False, max_depth=1
+        url, extractor=lambda _: "placeholder", use_async=False, max_depth=2
     )
     docs = loader.load()
-    assert len(docs) == 24
+    assert len(docs) == 27
     assert docs[0].page_content == "placeholder"
 
 
+@pytest.mark.asyncio
+def test_sync_async_equivalent() -> None:
+    url = "https://docs.python.org/3.9/"
+    loader = RecursiveUrlLoader(url, use_async=False, max_depth=2)
+    async_loader = RecursiveUrlLoader(url, use_async=False, max_depth=2)
+    docs = sorted(loader.load(), key=lambda d: d.metadata["source"])
+    async_docs = sorted(async_loader.load(), key=lambda d: d.metadata["source"])
+    assert docs == async_docs
+
+
 def test_loading_invalid_url() -> None:
     url = "https://this.url.is.invalid/this/is/a/test"
     loader = RecursiveUrlLoader(
-        url=url, max_depth=1, extractor=lambda _: "placeholder", use_async=False
+        url, max_depth=1, extractor=lambda _: "placeholder", use_async=False
     )
     docs = loader.load()
     assert len(docs) == 0
libs/langchain/tests/unit_tests/utils/test_html.py (new file, 109 lines)
@@ -0,0 +1,109 @@
+from langchain.utils.html import (
+    PREFIXES_TO_IGNORE,
+    SUFFIXES_TO_IGNORE,
+    extract_sub_links,
+    find_all_links,
+)
+
+
+def test_find_all_links_none() -> None:
+    html = "<span>Hello world</span>"
+    actual = find_all_links(html)
+    assert actual == []
+
+
+def test_find_all_links_single() -> None:
+    htmls = [
+        "href='foobar.com'",
+        'href="foobar.com"',
+        '<div><a class="blah" href="foobar.com">hullo</a></div>',
+    ]
+    actual = [find_all_links(html) for html in htmls]
+    assert actual == [["foobar.com"]] * 3
+
+
+def test_find_all_links_multiple() -> None:
+    html = (
+        '<div><a class="blah" href="https://foobar.com">hullo</a></div>'
+        '<div><a class="bleh" href="/baz/cool">buhbye</a></div>'
+    )
+    actual = find_all_links(html)
+    assert sorted(actual) == [
+        "/baz/cool",
+        "https://foobar.com",
+    ]
+
+
+def test_find_all_links_ignore_suffix() -> None:
+    html = 'href="foobar{suffix}"'
+    for suffix in SUFFIXES_TO_IGNORE:
+        actual = find_all_links(html.format(suffix=suffix))
+        assert actual == []
+
+    # Don't ignore if pattern doesn't occur at end of link.
+    html = 'href="foobar{suffix}more"'
+    for suffix in SUFFIXES_TO_IGNORE:
+        actual = find_all_links(html.format(suffix=suffix))
+        assert actual == [f"foobar{suffix}more"]
+
+
+def test_find_all_links_ignore_prefix() -> None:
+    html = 'href="{prefix}foobar"'
+    for prefix in PREFIXES_TO_IGNORE:
+        actual = find_all_links(html.format(prefix=prefix))
+        assert actual == []
+
+    # Don't ignore if pattern doesn't occur at beginning of link.
+    html = 'href="foobar{prefix}more"'
+    for prefix in PREFIXES_TO_IGNORE:
+        # Pound signs are split on when not prefixes.
+        if prefix == "#":
+            continue
+        actual = find_all_links(html.format(prefix=prefix))
+        assert actual == [f"foobar{prefix}more"]
+
+
+def test_find_all_links_drop_fragment() -> None:
+    html = 'href="foobar.com/woah#section_one"'
+    actual = find_all_links(html)
+    assert actual == ["foobar.com/woah"]
+
+
+def test_extract_sub_links() -> None:
+    html = (
+        '<a href="https://foobar.com">one</a>'
+        '<a href="http://baz.net">two</a>'
+        '<a href="//foobar.com/hello">three</a>'
+        '<a href="/how/are/you/doing">four</a>'
+    )
+    expected = sorted(
+        [
+            "https://foobar.com",
+            "https://foobar.com/hello",
+            "https://foobar.com/how/are/you/doing",
+        ]
+    )
+    actual = sorted(extract_sub_links(html, "https://foobar.com"))
+    assert actual == expected
+
+    actual = sorted(extract_sub_links(html, "https://foobar.com/hello"))
+    expected = sorted(
+        [
+            "https://foobar.com/hello",
+            "https://foobar.com/how/are/you/doing",
+        ]
+    )
+    assert actual == expected
+
+    actual = sorted(
+        extract_sub_links(html, "https://foobar.com/hello", prevent_outside=False)
+    )
+    expected = sorted(
+        [
+            "https://foobar.com",
+            "http://baz.net",
+            "https://foobar.com/hello",
+            "https://foobar.com/how/are/you/doing",
+        ]
+    )
+    assert actual == expected