Mirror of https://github.com/hwchase17/langchain.git, synced 2025-07-17 02:03:44 +00:00
community[patch]: recursive url loader fix and unit tests (#22521)
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent 234394f631
commit ba3e219d83
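
For context, a minimal usage sketch of the loader this commit patches; this is illustrative only, and the URL and argument values are placeholders rather than anything taken from the diff:

# Synchronous use of RecursiveUrlLoader (placeholder values throughout).
from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader

loader = RecursiveUrlLoader(
    "https://docs.python.org/3.9/",  # placeholder root url
    max_depth=2,                 # follow links up to two hops from the root
    use_async=False,             # sync path; True switches to the aiohttp path
    check_response_status=True,  # skip URLs with 400-599 responses
)
docs = loader.load()
print(len(docs))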
@@ -134,6 +134,7 @@ class RecursiveUrlLoader(BaseLoader):
         prevent_outside: If True, prevent loading from urls which are not children
             of the root url.
         link_regex: Regex for extracting sub-links from the raw html of a web page.
+        headers: Default request headers to use for all requests.
         check_response_status: If True, check HTTP response status and skip
             URLs with error responses (400-599).
         continue_on_failure: If True, continue if getting or parsing a link raises
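
The newly documented headers argument can be exercised like this; a hedged sketch, with a placeholder header value:

from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader

# `headers` is passed through as default request headers for every fetch,
# per the docstring line added above. The User-Agent string is hypothetical.
loader = RecursiveUrlLoader(
    "https://example.com",
    headers={"User-Agent": "my-crawler/0.1"},
)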
@@ -169,7 +170,6 @@ class RecursiveUrlLoader(BaseLoader):
         self.timeout = timeout
         self.prevent_outside = prevent_outside if prevent_outside is not None else True
         self.link_regex = link_regex
-        self._lock = asyncio.Lock() if self.use_async else None
         self.headers = headers
         self.check_response_status = check_response_status
         self.continue_on_failure = continue_on_failure
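
Why dropping self._lock is safe (reasoning inferred from the change, not stated in the commit): asyncio coroutines all run on one thread, so plain operations on a shared set cannot interleave unless a coroutine awaits between them. A minimal standalone sketch of that property, with hypothetical names:

import asyncio

visited: set = set()  # shared state, mutated only between awaits

async def crawl(url: str) -> None:
    # No await sits between the membership check and the add, so no other
    # coroutine can run in between -- a lock buys nothing on one event loop.
    if url in visited:
        return
    visited.add(url)
    await asyncio.sleep(0)  # stand-in for the real HTTP fetch

async def main() -> None:
    await asyncio.gather(*(crawl(u) for u in ["a", "a", "b"]))
    assert visited == {"a", "b"}

asyncio.run(main())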
@@ -249,7 +249,7 @@ class RecursiveUrlLoader(BaseLoader):
             visited: A set of visited URLs.
             depth: To reach the current url, how many pages have been visited.
         """
-        if not self.use_async or not self._lock:
+        if not self.use_async:
             raise ValueError(
                 "Async functions forbidden when not initialized with `use_async`"
             )
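
The simplified guard still fails fast when the async path is used on a loader built without use_async; a sketch of a test for that behavior, assuming the private method signature shown in the hunks below:

import asyncio
import pytest
from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader

def test_async_guard() -> None:
    loader = RecursiveUrlLoader("https://example.com", use_async=False)
    with pytest.raises(ValueError):
        # The guard raises before any network traffic happens.
        asyncio.run(
            loader._async_get_child_links_recursive("https://example.com", set())
        )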
@@ -276,8 +276,7 @@ class RecursiveUrlLoader(BaseLoader):
                 headers=self.headers,
             )
         )
-        async with self._lock:
-            visited.add(url)
+        visited.add(url)
         try:
             async with session.get(url) as response:
                 text = await response.text()
@@ -316,14 +315,13 @@ class RecursiveUrlLoader(BaseLoader):
 
         # Recursively call the function to get the children of the children
         sub_tasks = []
-        async with self._lock:
-            to_visit = set(sub_links).difference(visited)
-            for link in to_visit:
-                sub_tasks.append(
-                    self._async_get_child_links_recursive(
-                        link, visited, session=session, depth=depth + 1
-                    )
-                )
+        to_visit = set(sub_links).difference(visited)
+        for link in to_visit:
+            sub_tasks.append(
+                self._async_get_child_links_recursive(
+                    link, visited, session=session, depth=depth + 1
+                )
+            )
         next_results = await asyncio.gather(*sub_tasks)
         for sub_result in next_results:
             if isinstance(sub_result, Exception) or sub_result is None:
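
The trailing isinstance(sub_result, Exception) filter matches the usual asyncio.gather(..., return_exceptions=True) pattern; a standalone sketch of that pattern (illustrative only -- in the hunk above, gather is called without the flag, and failed fetches may instead surface as None results):

import asyncio

async def fetch(url: str) -> str:
    if "bad" in url:
        raise RuntimeError(f"failed to fetch {url}")
    return f"<html>{url}</html>"

async def main() -> None:
    # With return_exceptions=True, failures come back as values instead of
    # propagating, so one bad URL does not cancel its sibling tasks.
    results = await asyncio.gather(
        fetch("https://ok.example"),
        fetch("https://bad.example"),
        return_exceptions=True,
    )
    for result in results:
        if isinstance(result, Exception) or result is None:
            continue  # skip failures, mirroring the loader's filter
        print(result)

asyncio.run(main())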
@@ -0,0 +1,99 @@
+from __future__ import annotations
+
+import inspect
+import uuid
+from types import TracebackType
+from typing import Any, Type
+
+import aiohttp
+import pytest
+import requests_mock
+
+from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader
+
+link_to_one_two = """
+<div><a href="/one">link_to_one</a></div>
+<div><a href="/two">link_to_two</a></div>
+"""
+link_to_three = '<div><a href="../three">link_to_three</a></div>'
+no_links = "<p>no links<p>"
+
+fake_url = f"https://{uuid.uuid4()}.com"
+URL_TO_HTML = {
+    fake_url: link_to_one_two,
+    f"{fake_url}/one": link_to_three,
+    f"{fake_url}/two": link_to_three,
+    f"{fake_url}/three": no_links,
+}
+
+
+class MockGet:
+    def __init__(self, url: str) -> None:
+        self._text = URL_TO_HTML[url]
+        self.headers: dict = {}
+
+    async def text(self) -> str:
+        return self._text
+
+    async def __aexit__(
+        self, exc_type: Type[BaseException], exc: BaseException, tb: TracebackType
+    ) -> None:
+        pass
+
+    async def __aenter__(self) -> MockGet:
+        return self
+
+
+@pytest.mark.parametrize(("max_depth", "expected_docs"), [(1, 1), (2, 3), (3, 4)])
+@pytest.mark.parametrize("use_async", [False, True])
+def test_lazy_load(
+    mocker: Any, max_depth: int, expected_docs: int, use_async: bool
+) -> None:
+    loader = RecursiveUrlLoader(fake_url, max_depth=max_depth, use_async=use_async)
+    if use_async:
+        mocker.patch.object(aiohttp.ClientSession, "get", new=MockGet)
+        docs = list(loader.lazy_load())
+    else:
+        with requests_mock.Mocker() as m:
+            for url, html in URL_TO_HTML.items():
+                m.get(url, text=html)
+            docs = list(loader.lazy_load())
+    assert len(docs) == expected_docs
+
+
+@pytest.mark.parametrize(("max_depth", "expected_docs"), [(1, 1), (2, 3), (3, 4)])
+@pytest.mark.parametrize("use_async", [False, True])
+async def test_alazy_load(
+    mocker: Any, max_depth: int, expected_docs: int, use_async: bool
+) -> None:
+    loader = RecursiveUrlLoader(fake_url, max_depth=max_depth, use_async=use_async)
+    if use_async:
+        mocker.patch.object(aiohttp.ClientSession, "get", new=MockGet)
+        docs = []
+        async for doc in loader.alazy_load():
+            docs.append(doc)
+    else:
+        with requests_mock.Mocker() as m:
+            for url, html in URL_TO_HTML.items():
+                m.get(url, text=html)
+            docs = []
+            async for doc in loader.alazy_load():
+                docs.append(doc)
+
+    assert len(docs) == expected_docs
+
+
+def test_init_args_documented() -> None:
+    cls_docstring = RecursiveUrlLoader.__doc__ or ""
+    init_docstring = RecursiveUrlLoader.__init__.__doc__ or ""
+    all_docstring = cls_docstring + init_docstring
+    init_args = list(inspect.signature(RecursiveUrlLoader.__init__).parameters)
+    undocumented = [arg for arg in init_args[1:] if f"{arg}:" not in all_docstring]
+    assert not undocumented
+
+
+@pytest.mark.parametrize("method", ["load", "aload", "lazy_load", "alazy_load"])
+def test_no_runtime_args(method: str) -> None:
+    method_attr = getattr(RecursiveUrlLoader, method)
+    args = list(inspect.signature(method_attr).parameters)
+    assert args == ["self"]
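
For reference, a standalone illustration of the requests_mock pattern the sync branch of these tests relies on; the URL and payload are placeholders:

import requests
import requests_mock

with requests_mock.Mocker() as m:
    # Inside the context, requests traffic is intercepted and answered
    # with the registered stub instead of hitting the network.
    m.get("https://fake.test/page", text="<p>hello</p>")
    resp = requests.get("https://fake.test/page")
    assert resp.text == "<p>hello</p>"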