mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-16 09:48:04 +00:00
community[patch]: recursive url loader fix and unit tests (#22521)
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
234394f631
commit
ba3e219d83
@ -134,6 +134,7 @@ class RecursiveUrlLoader(BaseLoader):
|
||||
prevent_outside: If True, prevent loading from urls which are not children
|
||||
of the root url.
|
||||
link_regex: Regex for extracting sub-links from the raw html of a web page.
|
||||
headers: Default request headers to use for all requests.
|
||||
check_response_status: If True, check HTTP response status and skip
|
||||
URLs with error responses (400-599).
|
||||
continue_on_failure: If True, continue if getting or parsing a link raises
|
||||
@ -169,7 +170,6 @@ class RecursiveUrlLoader(BaseLoader):
|
||||
self.timeout = timeout
|
||||
self.prevent_outside = prevent_outside if prevent_outside is not None else True
|
||||
self.link_regex = link_regex
|
||||
self._lock = asyncio.Lock() if self.use_async else None
|
||||
self.headers = headers
|
||||
self.check_response_status = check_response_status
|
||||
self.continue_on_failure = continue_on_failure
|
||||
@ -249,7 +249,7 @@ class RecursiveUrlLoader(BaseLoader):
|
||||
visited: A set of visited URLs.
|
||||
depth: To reach the current url, how many pages have been visited.
|
||||
"""
|
||||
if not self.use_async or not self._lock:
|
||||
if not self.use_async:
|
||||
raise ValueError(
|
||||
"Async functions forbidden when not initialized with `use_async`"
|
||||
)
|
||||
@ -276,8 +276,7 @@ class RecursiveUrlLoader(BaseLoader):
|
||||
headers=self.headers,
|
||||
)
|
||||
)
|
||||
async with self._lock:
|
||||
visited.add(url)
|
||||
visited.add(url)
|
||||
try:
|
||||
async with session.get(url) as response:
|
||||
text = await response.text()
|
||||
@ -316,14 +315,13 @@ class RecursiveUrlLoader(BaseLoader):
|
||||
|
||||
# Recursively call the function to get the children of the children
|
||||
sub_tasks = []
|
||||
async with self._lock:
|
||||
to_visit = set(sub_links).difference(visited)
|
||||
for link in to_visit:
|
||||
sub_tasks.append(
|
||||
self._async_get_child_links_recursive(
|
||||
link, visited, session=session, depth=depth + 1
|
||||
)
|
||||
to_visit = set(sub_links).difference(visited)
|
||||
for link in to_visit:
|
||||
sub_tasks.append(
|
||||
self._async_get_child_links_recursive(
|
||||
link, visited, session=session, depth=depth + 1
|
||||
)
|
||||
)
|
||||
next_results = await asyncio.gather(*sub_tasks)
|
||||
for sub_result in next_results:
|
||||
if isinstance(sub_result, Exception) or sub_result is None:
|
||||
|
@ -0,0 +1,99 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import uuid
|
||||
from types import TracebackType
|
||||
from typing import Any, Type
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
import requests_mock
|
||||
|
||||
from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader
|
||||
|
||||
link_to_one_two = """
|
||||
<div><a href="/one">link_to_one</a></div>
|
||||
<div><a href="/two">link_to_two</a></div>
|
||||
"""
|
||||
link_to_three = '<div><a href="../three">link_to_three</a></div>'
|
||||
no_links = "<p>no links<p>"
|
||||
|
||||
fake_url = f"https://{uuid.uuid4()}.com"
|
||||
URL_TO_HTML = {
|
||||
fake_url: link_to_one_two,
|
||||
f"{fake_url}/one": link_to_three,
|
||||
f"{fake_url}/two": link_to_three,
|
||||
f"{fake_url}/three": no_links,
|
||||
}
|
||||
|
||||
|
||||
class MockGet:
    """Minimal stand-in for ``aiohttp.ClientSession.get``.

    The tests patch this class in as the ``get`` *class attribute*, so it is
    invoked as ``MockGet(url)`` and must act like an async context manager
    whose ``text()`` coroutine returns canned HTML for the requested URL.
    """

    def __init__(self, url: str) -> None:
        # Look the canned page up eagerly so an unexpected URL fails fast.
        self._text = URL_TO_HTML[url]
        self.headers: dict = {}

    async def text(self) -> str:
        """Return the canned HTML body for the requested URL."""
        return self._text

    async def __aenter__(self) -> MockGet:
        return self

    async def __aexit__(
        self, exc_type: Type[BaseException], exc: BaseException, tb: TracebackType
    ) -> None:
        pass
||||
@pytest.mark.parametrize(("max_depth", "expected_docs"), [(1, 1), (2, 3), (3, 4)])
|
||||
@pytest.mark.parametrize("use_async", [False, True])
|
||||
def test_lazy_load(
|
||||
mocker: Any, max_depth: int, expected_docs: int, use_async: bool
|
||||
) -> None:
|
||||
loader = RecursiveUrlLoader(fake_url, max_depth=max_depth, use_async=use_async)
|
||||
if use_async:
|
||||
mocker.patch.object(aiohttp.ClientSession, "get", new=MockGet)
|
||||
docs = list(loader.lazy_load())
|
||||
else:
|
||||
with requests_mock.Mocker() as m:
|
||||
for url, html in URL_TO_HTML.items():
|
||||
m.get(url, text=html)
|
||||
docs = list(loader.lazy_load())
|
||||
assert len(docs) == expected_docs
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("max_depth", "expected_docs"), [(1, 1), (2, 3), (3, 4)])
|
||||
@pytest.mark.parametrize("use_async", [False, True])
|
||||
async def test_alazy_load(
|
||||
mocker: Any, max_depth: int, expected_docs: int, use_async: bool
|
||||
) -> None:
|
||||
loader = RecursiveUrlLoader(fake_url, max_depth=max_depth, use_async=use_async)
|
||||
if use_async:
|
||||
mocker.patch.object(aiohttp.ClientSession, "get", new=MockGet)
|
||||
docs = []
|
||||
async for doc in loader.alazy_load():
|
||||
docs.append(doc)
|
||||
else:
|
||||
with requests_mock.Mocker() as m:
|
||||
for url, html in URL_TO_HTML.items():
|
||||
m.get(url, text=html)
|
||||
docs = []
|
||||
async for doc in loader.alazy_load():
|
||||
docs.append(doc)
|
||||
|
||||
assert len(docs) == expected_docs
|
||||
|
||||
|
||||
def test_init_args_documented() -> None:
    """Every ``__init__`` parameter is mentioned in the class or init docstring."""
    docs = (RecursiveUrlLoader.__doc__ or "") + (
        RecursiveUrlLoader.__init__.__doc__ or ""
    )
    params = inspect.signature(RecursiveUrlLoader.__init__).parameters
    # Skip "self"; a parameter counts as documented if "<name>:" appears anywhere.
    undocumented = [name for name in list(params)[1:] if f"{name}:" not in docs]
    assert not undocumented
||||
@pytest.mark.parametrize("method", ["load", "aload", "lazy_load", "alazy_load"])
|
||||
def test_no_runtime_args(method: str) -> None:
|
||||
method_attr = getattr(RecursiveUrlLoader, method)
|
||||
args = list(inspect.signature(method_attr).parameters)
|
||||
assert args == ["self"]
|
Loading…
Reference in New Issue
Block a user