From ba3e219d835686690141b715e9717493d4adda6d Mon Sep 17 00:00:00 2001 From: Isaac Francisco <78627776+isahers1@users.noreply.github.com> Date: Wed, 5 Jun 2024 17:56:20 -0700 Subject: [PATCH] community[patch]: recursive url loader fix and unit tests (#22521) Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> Co-authored-by: Bagatur --- .../document_loaders/recursive_url_loader.py | 20 ++-- .../test_recursive_url_loader.py | 99 +++++++++++++++++++ 2 files changed, 108 insertions(+), 11 deletions(-) create mode 100644 libs/community/tests/unit_tests/document_loaders/test_recursive_url_loader.py diff --git a/libs/community/langchain_community/document_loaders/recursive_url_loader.py b/libs/community/langchain_community/document_loaders/recursive_url_loader.py index b5e9bfabcfa..83e62e1d896 100644 --- a/libs/community/langchain_community/document_loaders/recursive_url_loader.py +++ b/libs/community/langchain_community/document_loaders/recursive_url_loader.py @@ -134,6 +134,7 @@ class RecursiveUrlLoader(BaseLoader): prevent_outside: If True, prevent loading from urls which are not children of the root url. link_regex: Regex for extracting sub-links from the raw html of a web page. + headers: Default request headers to use for all requests. check_response_status: If True, check HTTP response status and skip URLs with error responses (400-599). continue_on_failure: If True, continue if getting or parsing a link raises @@ -169,7 +170,6 @@ class RecursiveUrlLoader(BaseLoader): self.timeout = timeout self.prevent_outside = prevent_outside if prevent_outside is not None else True self.link_regex = link_regex - self._lock = asyncio.Lock() if self.use_async else None self.headers = headers self.check_response_status = check_response_status self.continue_on_failure = continue_on_failure @@ -249,7 +249,7 @@ class RecursiveUrlLoader(BaseLoader): visited: A set of visited URLs. depth: To reach the current url, how many pages have been visited. 
""" - if not self.use_async or not self._lock: + if not self.use_async: raise ValueError( "Async functions forbidden when not initialized with `use_async`" ) @@ -276,8 +276,7 @@ class RecursiveUrlLoader(BaseLoader): headers=self.headers, ) ) - async with self._lock: - visited.add(url) + visited.add(url) try: async with session.get(url) as response: text = await response.text() @@ -316,14 +315,13 @@ class RecursiveUrlLoader(BaseLoader): # Recursively call the function to get the children of the children sub_tasks = [] - async with self._lock: - to_visit = set(sub_links).difference(visited) - for link in to_visit: - sub_tasks.append( - self._async_get_child_links_recursive( - link, visited, session=session, depth=depth + 1 - ) + to_visit = set(sub_links).difference(visited) + for link in to_visit: + sub_tasks.append( + self._async_get_child_links_recursive( + link, visited, session=session, depth=depth + 1 ) + ) next_results = await asyncio.gather(*sub_tasks) for sub_result in next_results: if isinstance(sub_result, Exception) or sub_result is None: diff --git a/libs/community/tests/unit_tests/document_loaders/test_recursive_url_loader.py b/libs/community/tests/unit_tests/document_loaders/test_recursive_url_loader.py new file mode 100644 index 00000000000..55e00d99765 --- /dev/null +++ b/libs/community/tests/unit_tests/document_loaders/test_recursive_url_loader.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +import inspect +import uuid +from types import TracebackType +from typing import Any, Type + +import aiohttp +import pytest +import requests_mock + +from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader + +link_to_one_two = """ +
+<div><a href="/one">link_to_one</a></div>
+<div><a href="/two">link_to_two</a></div>
+""" +link_to_three = '
<div><a href="../three">link_to_three</a></div>
' +no_links = "

no links

" + +fake_url = f"https://{uuid.uuid4()}.com" +URL_TO_HTML = { + fake_url: link_to_one_two, + f"{fake_url}/one": link_to_three, + f"{fake_url}/two": link_to_three, + f"{fake_url}/three": no_links, +} + + +class MockGet: + def __init__(self, url: str) -> None: + self._text = URL_TO_HTML[url] + self.headers: dict = {} + + async def text(self) -> str: + return self._text + + async def __aexit__( + self, exc_type: Type[BaseException], exc: BaseException, tb: TracebackType + ) -> None: + pass + + async def __aenter__(self) -> MockGet: + return self + + +@pytest.mark.parametrize(("max_depth", "expected_docs"), [(1, 1), (2, 3), (3, 4)]) +@pytest.mark.parametrize("use_async", [False, True]) +def test_lazy_load( + mocker: Any, max_depth: int, expected_docs: int, use_async: bool +) -> None: + loader = RecursiveUrlLoader(fake_url, max_depth=max_depth, use_async=use_async) + if use_async: + mocker.patch.object(aiohttp.ClientSession, "get", new=MockGet) + docs = list(loader.lazy_load()) + else: + with requests_mock.Mocker() as m: + for url, html in URL_TO_HTML.items(): + m.get(url, text=html) + docs = list(loader.lazy_load()) + assert len(docs) == expected_docs + + +@pytest.mark.parametrize(("max_depth", "expected_docs"), [(1, 1), (2, 3), (3, 4)]) +@pytest.mark.parametrize("use_async", [False, True]) +async def test_alazy_load( + mocker: Any, max_depth: int, expected_docs: int, use_async: bool +) -> None: + loader = RecursiveUrlLoader(fake_url, max_depth=max_depth, use_async=use_async) + if use_async: + mocker.patch.object(aiohttp.ClientSession, "get", new=MockGet) + docs = [] + async for doc in loader.alazy_load(): + docs.append(doc) + else: + with requests_mock.Mocker() as m: + for url, html in URL_TO_HTML.items(): + m.get(url, text=html) + docs = [] + async for doc in loader.alazy_load(): + docs.append(doc) + + assert len(docs) == expected_docs + + +def test_init_args_documented() -> None: + cls_docstring = RecursiveUrlLoader.__doc__ or "" + init_docstring = 
RecursiveUrlLoader.__init__.__doc__ or "" + all_docstring = cls_docstring + init_docstring + init_args = list(inspect.signature(RecursiveUrlLoader.__init__).parameters) + undocumented = [arg for arg in init_args[1:] if f"{arg}:" not in all_docstring] + assert not undocumented + + +@pytest.mark.parametrize("method", ["load", "aload", "lazy_load", "alazy_load"]) +def test_no_runtime_args(method: str) -> None: + method_attr = getattr(RecursiveUrlLoader, method) + args = list(inspect.signature(method_attr).parameters) + assert args == ["self"]