community[patch]: recursive url loader fix and unit tests (#22521)

Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Committed by Isaac Francisco on 2024-06-05 17:56:20 -07:00 via GitHub.
parent 234394f631
commit ba3e219d83
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 108 additions and 11 deletions

View File

@ -134,6 +134,7 @@ class RecursiveUrlLoader(BaseLoader):
prevent_outside: If True, prevent loading from urls which are not children
of the root url.
link_regex: Regex for extracting sub-links from the raw html of a web page.
headers: Default request headers to use for all requests.
check_response_status: If True, check HTTP response status and skip
URLs with error responses (400-599).
continue_on_failure: If True, continue if getting or parsing a link raises
@ -169,7 +170,6 @@ class RecursiveUrlLoader(BaseLoader):
self.timeout = timeout
self.prevent_outside = prevent_outside if prevent_outside is not None else True
self.link_regex = link_regex
self._lock = asyncio.Lock() if self.use_async else None
self.headers = headers
self.check_response_status = check_response_status
self.continue_on_failure = continue_on_failure
@ -249,7 +249,7 @@ class RecursiveUrlLoader(BaseLoader):
visited: A set of visited URLs.
depth: To reach the current url, how many pages have been visited.
"""
if not self.use_async or not self._lock:
if not self.use_async:
raise ValueError(
"Async functions forbidden when not initialized with `use_async`"
)
@ -276,8 +276,7 @@ class RecursiveUrlLoader(BaseLoader):
headers=self.headers,
)
)
async with self._lock:
visited.add(url)
visited.add(url)
try:
async with session.get(url) as response:
text = await response.text()
@ -316,14 +315,13 @@ class RecursiveUrlLoader(BaseLoader):
# Recursively call the function to get the children of the children
sub_tasks = []
async with self._lock:
to_visit = set(sub_links).difference(visited)
for link in to_visit:
sub_tasks.append(
self._async_get_child_links_recursive(
link, visited, session=session, depth=depth + 1
)
to_visit = set(sub_links).difference(visited)
for link in to_visit:
sub_tasks.append(
self._async_get_child_links_recursive(
link, visited, session=session, depth=depth + 1
)
)
next_results = await asyncio.gather(*sub_tasks)
for sub_result in next_results:
if isinstance(sub_result, Exception) or sub_result is None:

View File

@ -0,0 +1,99 @@
from __future__ import annotations
import inspect
import uuid
from types import TracebackType
from typing import Any, Type
import aiohttp
import pytest
import requests_mock
from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader
# Canned HTML fixtures: the fake site is a tiny 4-page graph rooted at
# fake_url -> /one, /two -> ../three -> (no links).
link_to_one_two = """
<div><a href="/one">link_to_one</a></div>
<div><a href="/two">link_to_two</a></div>
"""
link_to_three = '<div><a href="../three">link_to_three</a></div>'
no_links = "<p>no links<p>"
# Random host so an accidental real network call can never resolve.
fake_url = f"https://{uuid.uuid4()}.com"
# Maps every URL the loader may visit to the HTML it should "download".
URL_TO_HTML = {
    fake_url: link_to_one_two,
    f"{fake_url}/one": link_to_three,
    f"{fake_url}/two": link_to_three,
    f"{fake_url}/three": no_links,
}
class MockGet:
    """Stand-in for ``aiohttp.ClientSession.get``.

    Mimics the response's async-context-manager protocol and ``text()``
    coroutine, serving the canned HTML from ``URL_TO_HTML``.
    """
    def __init__(self, url: str) -> None:
        # Resolve the page body eagerly; unknown URLs fail fast here.
        self._html = URL_TO_HTML[url]
        self.headers: dict = {}
    async def __aenter__(self) -> MockGet:
        return self
    async def __aexit__(
        self, exc_type: Type[BaseException], exc: BaseException, tb: TracebackType
    ) -> None:
        pass
    async def text(self) -> str:
        return self._html
@pytest.mark.parametrize(("max_depth", "expected_docs"), [(1, 1), (2, 3), (3, 4)])
@pytest.mark.parametrize("use_async", [False, True])
def test_lazy_load(
    mocker: Any, max_depth: int, expected_docs: int, use_async: bool
) -> None:
    """``lazy_load`` yields one Document per page reachable within max_depth."""
    loader = RecursiveUrlLoader(fake_url, max_depth=max_depth, use_async=use_async)
    if not use_async:
        # Sync path goes through `requests`; register every fake page.
        with requests_mock.Mocker() as rm:
            for page_url, page_html in URL_TO_HTML.items():
                rm.get(page_url, text=page_html)
            docs = list(loader.lazy_load())
        assert len(docs) == expected_docs
        return
    # Async path: replace aiohttp's get with the canned-response mock.
    mocker.patch.object(aiohttp.ClientSession, "get", new=MockGet)
    docs = list(loader.lazy_load())
    assert len(docs) == expected_docs
@pytest.mark.parametrize(("max_depth", "expected_docs"), [(1, 1), (2, 3), (3, 4)])
@pytest.mark.parametrize("use_async", [False, True])
async def test_alazy_load(
    mocker: Any, max_depth: int, expected_docs: int, use_async: bool
) -> None:
    """``alazy_load`` yields one Document per page reachable within max_depth."""
    loader = RecursiveUrlLoader(fake_url, max_depth=max_depth, use_async=use_async)
    if use_async:
        # Async path: replace aiohttp's get with the canned-response mock.
        mocker.patch.object(aiohttp.ClientSession, "get", new=MockGet)
        docs = [doc async for doc in loader.alazy_load()]
    else:
        # Sync path goes through `requests`; register every fake page.
        with requests_mock.Mocker() as rm:
            for page_url, page_html in URL_TO_HTML.items():
                rm.get(page_url, text=page_html)
            docs = [doc async for doc in loader.alazy_load()]
    assert len(docs) == expected_docs
def test_init_args_documented() -> None:
    """Every ``__init__`` parameter (beyond self) appears in the docstrings."""
    combined_docs = (RecursiveUrlLoader.__doc__ or "") + (
        RecursiveUrlLoader.__init__.__doc__ or ""
    )
    signature = inspect.signature(RecursiveUrlLoader.__init__)
    # A parameter counts as documented when "<name>:" occurs in either docstring.
    missing = [
        name
        for name in list(signature.parameters)[1:]
        if f"{name}:" not in combined_docs
    ]
    assert not missing
@pytest.mark.parametrize("method", ["load", "aload", "lazy_load", "alazy_load"])
def test_no_runtime_args(method: str) -> None:
    """Loader entry points take no arguments beyond ``self``."""
    entry_point = getattr(RecursiveUrlLoader, method)
    params = inspect.signature(entry_point).parameters
    assert list(params) == ["self"]