diff --git a/libs/community/langchain_community/document_loaders/recursive_url_loader.py b/libs/community/langchain_community/document_loaders/recursive_url_loader.py
index 83e62e1d896..60f3bb70267 100644
--- a/libs/community/langchain_community/document_loaders/recursive_url_loader.py
+++ b/libs/community/langchain_community/document_loaders/recursive_url_loader.py
@@ -5,6 +5,7 @@ import inspect
 import logging
 import re
 from typing import (
+    AsyncIterator,
     Callable,
     Iterator,
     List,
@@ -86,6 +87,7 @@ class RecursiveUrlLoader(BaseLoader):
         self,
         url: str,
         max_depth: Optional[int] = 2,
+        # TODO: Deprecate use_async
         use_async: Optional[bool] = None,
         extractor: Optional[Callable[[str], str]] = None,
         metadata_extractor: Optional[_MetadataExtractorType] = None,
@@ -100,6 +102,7 @@ class RecursiveUrlLoader(BaseLoader):
         base_url: Optional[str] = None,
         autoset_encoding: bool = True,
         encoding: Optional[str] = None,
+        retries: int = 0,
     ) -> None:
         """Initialize with URL to crawl and any subdirectories to exclude.
 
@@ -147,26 +150,25 @@ class RecursiveUrlLoader(BaseLoader):
             set to given value, regardless of the `autoset_encoding` argument.
         """  # noqa: E501
+        if exclude_dirs and any(url.startswith(dir) for dir in exclude_dirs):
+            raise ValueError(
+                f"Base url is included in exclude_dirs. Received base_url: {url} and "
+                f"exclude_dirs: {exclude_dirs}"
+            )
+        if max_depth is not None and max_depth < 1:
+            raise ValueError("max_depth must be positive.")
+
         self.url = url
         self.max_depth = max_depth if max_depth is not None else 2
         self.use_async = use_async if use_async is not None else False
         self.extractor = extractor if extractor is not None else lambda x: x
-        metadata_extractor = (
-            metadata_extractor
-            if metadata_extractor is not None
-            else _metadata_extractor
-        )
+        if metadata_extractor is None:
+            self.metadata_extractor = _metadata_extractor
+        else:
+            self.metadata_extractor = _wrap_metadata_extractor(metadata_extractor)
         self.autoset_encoding = autoset_encoding
         self.encoding = encoding
-        self.metadata_extractor = _wrap_metadata_extractor(metadata_extractor)
         self.exclude_dirs = exclude_dirs if exclude_dirs is not None else ()
-
-        if any(url.startswith(exclude_dir) for exclude_dir in self.exclude_dirs):
-            raise ValueError(
-                f"Base url is included in exclude_dirs. Received base_url: {url} and "
-                f"exclude_dirs: {self.exclude_dirs}"
-            )
-
         self.timeout = timeout
         self.prevent_outside = prevent_outside if prevent_outside is not None else True
         self.link_regex = link_regex
@@ -174,52 +176,101 @@ class RecursiveUrlLoader(BaseLoader):
         self.check_response_status = check_response_status
         self.continue_on_failure = continue_on_failure
         self.base_url = base_url if base_url is not None else url
+        self.retries = retries
 
-    def _get_child_links_recursive(
+    def _lazy_load_recursive(
         self, url: str, visited: Set[str], *, depth: int = 0
     ) -> Iterator[Document]:
-        """Recursively get all child links starting with the path of the input URL.
+        if url in visited:
+            raise ValueError
+        visited.add(url)
+        response = None
+        for _ in range(self.retries + 1):
+            response = self._request(url)
+            if response:
+                break
+        if not response:
+            return
+        text = response.text
+        if content := self.extractor(text):
+            metadata = self.metadata_extractor(text, url, response)
+            yield Document(content, metadata=metadata)
 
-        Args:
-            url: The URL to crawl.
-            visited: A set of visited URLs.
-            depth: Current depth of recursion. Stop when depth >= max_depth.
- """ + if depth + 1 < self.max_depth: + for link in self._extract_sub_links(text, url): + if link not in visited: + yield from self._lazy_load_recursive(link, visited, depth=depth + 1) + if link not in visited: + raise ValueError + async def _async_get_child_links_recursive( + self, + url: str, + visited: Set[str], + session: aiohttp.ClientSession, + *, + depth: int = 0, + ) -> AsyncIterator[Document]: if depth >= self.max_depth: return - # Get all links that can be accessed from the current URL visited.add(url) try: - response = requests.get(url, timeout=self.timeout, headers=self.headers) - - if self.encoding is not None: - response.encoding = self.encoding - elif self.autoset_encoding: - response.encoding = response.apparent_encoding - - if self.check_response_status and 400 <= response.status_code <= 599: - raise ValueError(f"Received HTTP status {response.status_code}") - except Exception as e: + async with session.get(url) as response: + if self.check_response_status: + response.raise_for_status() + text = await response.text() + except (aiohttp.client_exceptions.InvalidURL, Exception) as e: if self.continue_on_failure: logger.warning( - f"Unable to load from {url}. Received error {e} of type " + f"Unable to load {url}. Received error {e} of type " f"{e.__class__.__name__}" ) return else: raise e - content = self.extractor(response.text) - if content: - yield Document( - page_content=content, - metadata=self.metadata_extractor(response.text, url, response), - ) - # Store the visited links and recursively visit the children - sub_links = extract_sub_links( - response.text, + if content := self.extractor(text): + metadata = self.metadata_extractor(text, url, response) + yield Document(content, metadata=metadata) + for link in self._extract_sub_links(text, url): + if link not in visited: + async for doc in self._async_get_child_links_recursive( + link, visited, session, depth=depth + 1 + ): + yield doc + + def lazy_load(self) -> Iterator[Document]: + """Lazy load web pages. + When use_async is True, this function will not be lazy, + but it will still work in the expected way, just not lazy.""" + if self.use_async: + + async def aload(): + results = [] + async for doc in self.alazy_load(): + results.append(doc) + return results + + return iter(asyncio.run(aload())) + + else: + yield from self._lazy_load_recursive(self.url, set()) + + async def alazy_load(self) -> AsyncIterator[Document]: + async with aiohttp.ClientSession( + connector=aiohttp.TCPConnector(ssl=False), + timeout=aiohttp.ClientTimeout(total=self.timeout), + headers=self.headers, + ) as session: + async for doc in self._async_get_child_links_recursive( + self.url, set(), session + ): + yield doc + + def _extract_sub_links(self, html: str, url: str) -> List[str]: + return extract_sub_links( + html, url, base_url=self.base_url, pattern=self.link_regex, @@ -227,125 +278,31 @@ class RecursiveUrlLoader(BaseLoader): exclude_prefixes=self.exclude_dirs, continue_on_failure=self.continue_on_failure, ) - for link in sub_links: - # Check all unvisited links - if link not in visited: - yield from self._get_child_links_recursive( - link, visited, depth=depth + 1 - ) - - async def _async_get_child_links_recursive( - self, - url: str, - visited: Set[str], - *, - session: Optional[aiohttp.ClientSession] = None, - depth: int = 0, - ) -> List[Document]: - """Recursively get all child links starting with the path of the input URL. - - Args: - url: The URL to crawl. - visited: A set of visited URLs. 
-            depth: To reach the current url, how many pages have been visited.
-        """
-        if not self.use_async:
-            raise ValueError(
-                "Async functions forbidden when not initialized with `use_async`"
-            )
+    def _request(self, url: str) -> Optional[requests.Response]:
         try:
-            import aiohttp
-        except ImportError:
-            raise ImportError(
-                "The aiohttp package is required for the RecursiveUrlLoader. "
-                "Please install it with `pip install aiohttp`."
-            )
-        if depth >= self.max_depth:
-            return []
+            response = requests.get(url, timeout=self.timeout, headers=self.headers)
 
-        # Disable SSL verification because websites may have invalid SSL certificates,
-        # but won't cause any security issues for us.
-        close_session = session is None
-        session = (
-            session
-            if session is not None
-            else aiohttp.ClientSession(
-                connector=aiohttp.TCPConnector(ssl=False),
-                timeout=aiohttp.ClientTimeout(total=self.timeout),
-                headers=self.headers,
-            )
-        )
-        visited.add(url)
-        try:
-            async with session.get(url) as response:
-                text = await response.text()
-                if self.check_response_status and 400 <= response.status <= 599:
-                    raise ValueError(f"Received HTTP status {response.status}")
-        except (aiohttp.client_exceptions.InvalidURL, Exception) as e:
-            if close_session:
-                await session.close()
+            if self.encoding is not None:
+                response.encoding = self.encoding
+            elif self.autoset_encoding:
+                response.encoding = response.apparent_encoding
+            else:
+                pass
+
+            if self.check_response_status:
+                response.raise_for_status()
+        except Exception as e:
             if self.continue_on_failure:
                 logger.warning(
                     f"Unable to load {url}. Received error {e} of type "
                     f"{e.__class__.__name__}"
                 )
-                return []
+                return None
             else:
                 raise e
-        results = []
-        content = self.extractor(text)
-        if content:
-            results.append(
-                Document(
-                    page_content=content,
-                    metadata=self.metadata_extractor(text, url, response),
-                )
-            )
-        if depth < self.max_depth - 1:
-            sub_links = extract_sub_links(
-                text,
-                url,
-                base_url=self.base_url,
-                pattern=self.link_regex,
-                prevent_outside=self.prevent_outside,
-                exclude_prefixes=self.exclude_dirs,
-                continue_on_failure=self.continue_on_failure,
-            )
-
-            # Recursively call the function to get the children of the children
-            sub_tasks = []
-            to_visit = set(sub_links).difference(visited)
-            for link in to_visit:
-                sub_tasks.append(
-                    self._async_get_child_links_recursive(
-                        link, visited, session=session, depth=depth + 1
-                    )
-                )
-            next_results = await asyncio.gather(*sub_tasks)
-            for sub_result in next_results:
-                if isinstance(sub_result, Exception) or sub_result is None:
-                    # We don't want to stop the whole process, so just ignore it
-                    # Not standard html format or invalid url or 404 may cause this.
-                    continue
-                # locking not fully working, temporary hack to ensure deduplication
-                results += [r for r in sub_result if r not in results]
-        if close_session:
-            await session.close()
-        return results
-
-    def lazy_load(self) -> Iterator[Document]:
-        """Lazy load web pages.
-        When use_async is True, this function will not be lazy,
-        but it will still work in the expected way, just not lazy."""
-        visited: Set[str] = set()
-        if self.use_async:
-            results = asyncio.run(
-                self._async_get_child_links_recursive(self.url, visited)
-            )
-            return iter(results or [])
         else:
-            return self._get_child_links_recursive(self.url, visited)
+            return response
 
 
 _MetadataExtractorType1 = Callable[[str, str], dict]
diff --git a/libs/community/tests/integration_tests/document_loaders/test_recursive_url_loader.py b/libs/community/tests/integration_tests/document_loaders/test_recursive_url_loader.py
index 92c274f33ed..c0cec92db2b 100644
--- a/libs/community/tests/integration_tests/document_loaders/test_recursive_url_loader.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_recursive_url_loader.py
@@ -1,7 +1,9 @@
+from datetime import datetime
+
 from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader
 
 
-def test_async_recursive_url_loader() -> None:
+async def test_async_recursive_url_loader() -> None:
     url = "https://docs.python.org/3.9/"
     loader = RecursiveUrlLoader(
         url,
@@ -11,7 +13,7 @@ def test_async_recursive_url_loader() -> None:
         timeout=None,
         check_response_status=True,
     )
-    docs = loader.load()
+    docs = [document async for document in loader.alazy_load()]
     assert len(docs) == 512
     assert docs[0].page_content == "placeholder"
 
@@ -32,11 +34,22 @@ def test_async_recursive_url_loader_deterministic() -> None:
 def test_sync_recursive_url_loader() -> None:
     url = "https://docs.python.org/3.9/"
     loader = RecursiveUrlLoader(
-        url, extractor=lambda _: "placeholder", use_async=False, max_depth=2
+        url,
+        extractor=lambda _: "placeholder",
+        use_async=False,
+        max_depth=3,
+        timeout=None,
+        check_response_status=True,
+        retries=10,
     )
-    docs = loader.load()
-    assert len(docs) == 24
+    docs = [document for document in loader.lazy_load()]
+    with open(f"/Users/bagatur/Desktop/docs_{datetime.now()}.txt", "w") as f:
+        f.write("\n".join(doc.metadata["source"] for doc in docs))
     assert docs[0].page_content == "placeholder"
+    # no duplicates
+    deduped = [doc for i, doc in enumerate(docs) if doc not in docs[:i]]
+    assert len(docs) == len(deduped)
+    assert len(docs) == 512
 
 
 def test_sync_async_equivalent() -> None:
diff --git a/libs/community/tests/unit_tests/document_loaders/test_recursive_url_loader.py b/libs/community/tests/unit_tests/document_loaders/test_recursive_url_loader.py
index 55e00d99765..67307356d3b 100644
--- a/libs/community/tests/unit_tests/document_loaders/test_recursive_url_loader.py
+++ b/libs/community/tests/unit_tests/document_loaders/test_recursive_url_loader.py
@@ -11,19 +11,26 @@ import requests_mock
 
 from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader
 
+fake_url = f"https://{uuid.uuid4()}.com"
 link_to_one_two = """
 <div><a href="/one">link_to_one</a></div>
 <div><a href="/two">link_to_two</a></div>
 """
+link_to_two = '<div><a href="/two">link_to_two</a></div>'
 link_to_three = '<div><a href="/three">link_to_three</a></div>'
-no_links = "<p>no links<p>"
-
-fake_url = f"https://{uuid.uuid4()}.com"
+link_to_three_four_five = f"""
+<div><a href="{fake_url}/three">link_to_three</a></div>
+<div><a href="/four">link_to_four</a></div>
+<div><a href="/five/foo">link_to_five</a></div>
+"""
+link_to_index = f'<div><a href="{fake_url}">link_to_index</a></div>'
 
 URL_TO_HTML = {
     fake_url: link_to_one_two,
     f"{fake_url}/one": link_to_three,
-    f"{fake_url}/two": link_to_three,
-    f"{fake_url}/three": no_links,
+    f"{fake_url}/two": link_to_three_four_five,
+    f"{fake_url}/three": link_to_two,
+    f"{fake_url}/four": link_to_index,
+    f"{fake_url}/five/foo": link_to_three_four_five,
 }
 
 
@@ -44,7 +51,7 @@ class MockGet:
         return self
 
 
-@pytest.mark.parametrize(("max_depth", "expected_docs"), [(1, 1), (2, 3), (3, 4)])
+@pytest.mark.parametrize(("max_depth", "expected_docs"), [(1, 1), (2, 3), (3, 6)])
 @pytest.mark.parametrize("use_async", [False, True])
 def test_lazy_load(
     mocker: Any, max_depth: int, expected_docs: int, use_async: bool
diff --git a/libs/core/langchain_core/utils/html.py b/libs/core/langchain_core/utils/html.py
index 3e41c187b4a..2f9663aac97 100644
--- a/libs/core/langchain_core/utils/html.py
+++ b/libs/core/langchain_core/utils/html.py
@@ -1,3 +1,4 @@
+import importlib.util
 import logging
 import re
 from typing import List, Optional, Sequence, Union
@@ -43,6 +44,11 @@ def find_all_links(
     Returns:
         List[str]: all links
     """
+    if importlib.util.find_spec("bs4"):
+        from bs4 import BeautifulSoup
+
+        soup = BeautifulSoup(raw_html)
+        return [tag["href"] for tag in soup.findAll('a', href=True)]
    pattern = pattern or DEFAULT_LINK_REGEX
    return list(set(re.findall(pattern, raw_html)))
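
A minimal usage sketch of the reworked loader, assuming the patch above is applied: the retries keyword and the alazy_load async generator only exist with these changes, and the docs.python.org URL is purely illustrative (it mirrors the integration tests).

    import asyncio

    from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader

    loader = RecursiveUrlLoader(
        "https://docs.python.org/3.9/",  # illustrative target, same as in the tests
        max_depth=2,
        retries=2,  # retry each failed request up to 2 more times (new in this diff)
        check_response_status=True,
    )

    # Synchronous path: documents are yielded one at a time as pages are fetched.
    for doc in loader.lazy_load():
        print(doc.metadata["source"])

    # Asynchronous path: iterate the new alazy_load() async generator.
    async def crawl() -> None:
        async for doc in loader.alazy_load():
            print(doc.metadata["source"])

    asyncio.run(crawl())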