mirror of https://github.com/hwchase17/langchain.git
synced 2026-01-30 13:50:11 +00:00

Compare commits: 2 commits (langchain-... / bagatur/re)

| SHA1 |
|---|
| 8616e1c44a |
| 2e1dc2c660 |
@@ -5,6 +5,7 @@ import inspect
import logging
import re
from typing import (
    AsyncIterator,
    Callable,
    Iterator,
    List,

@@ -86,6 +87,7 @@ class RecursiveUrlLoader(BaseLoader):
        self,
        url: str,
        max_depth: Optional[int] = 2,
        # TODO: Deprecate use_async
        use_async: Optional[bool] = None,
        extractor: Optional[Callable[[str], str]] = None,
        metadata_extractor: Optional[_MetadataExtractorType] = None,

@@ -100,6 +102,7 @@ class RecursiveUrlLoader(BaseLoader):
        base_url: Optional[str] = None,
        autoset_encoding: bool = True,
        encoding: Optional[str] = None,
        retries: int = 0,
    ) -> None:
        """Initialize with URL to crawl and any subdirectories to exclude.

@@ -147,26 +150,25 @@ class RecursiveUrlLoader(BaseLoader):
            set to given value, regardless of the `autoset_encoding` argument.
        """ # noqa: E501

        if exclude_dirs and any(url.startswith(dir) for dir in exclude_dirs):
            raise ValueError(
                f"Base url is included in exclude_dirs. Received base_url: {url} and "
                f"exclude_dirs: {exclude_dirs}"
            )
        if max_depth is not None and max_depth < 1:
            raise ValueError("max_depth must be positive.")

        self.url = url
        self.max_depth = max_depth if max_depth is not None else 2
        self.use_async = use_async if use_async is not None else False
        self.extractor = extractor if extractor is not None else lambda x: x
        metadata_extractor = (
            metadata_extractor
            if metadata_extractor is not None
            else _metadata_extractor
        )
        if metadata_extractor is None:
            self.metadata_extractor = _metadata_extractor
        else:
            self.metadata_extractor = _wrap_metadata_extractor(metadata_extractor)
        self.autoset_encoding = autoset_encoding
        self.encoding = encoding
        self.metadata_extractor = _wrap_metadata_extractor(metadata_extractor)
        self.exclude_dirs = exclude_dirs if exclude_dirs is not None else ()

        if any(url.startswith(exclude_dir) for exclude_dir in self.exclude_dirs):
            raise ValueError(
                f"Base url is included in exclude_dirs. Received base_url: {url} and "
                f"exclude_dirs: {self.exclude_dirs}"
            )

        self.timeout = timeout
        self.prevent_outside = prevent_outside if prevent_outside is not None else True
        self.link_regex = link_regex
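As a reading aid, here is a minimal construction sketch that uses only arguments visible in the hunks above; `retries` is the argument this branch introduces, the other parameters already exist on the loader, and the chosen values are arbitrary.

```python
from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader

# Minimal sketch: construct the loader with arguments shown in the diff above.
loader = RecursiveUrlLoader(
    "https://docs.python.org/3.9/",  # url to crawl
    max_depth=2,                     # must be >= 1, otherwise ValueError
    use_async=False,                 # flagged "TODO: Deprecate" in the diff
    extractor=lambda html: html,     # identity extractor, same as the default
    exclude_dirs=(),                 # must not contain a prefix of the base url
    timeout=10,
    check_response_status=True,
    retries=2,                       # new on this branch: retry failed requests
)
```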
@@ -174,52 +176,102 @@ class RecursiveUrlLoader(BaseLoader):
        self.check_response_status = check_response_status
        self.continue_on_failure = continue_on_failure
        self.base_url = base_url if base_url is not None else url
        self.retries = retries

    def _get_child_links_recursive(
    def _lazy_load_recursive(
        self, url: str, visited: Set[str], *, depth: int = 0
    ) -> Iterator[Document]:
        """Recursively get all child links starting with the path of the input URL.
        if url in visited:
            raise ValueError
        visited.add(url)
        response = None
        for _ in range(self.retries + 1):
            response = self._request(url)
            if response:
                break
        if not response:
            return
        text = response.text
        if content := self.extractor(text):
            metadata = self.metadata_extractor(text, url, response)
            yield Document(content, metadata=metadata)

        Args:
            url: The URL to crawl.
            visited: A set of visited URLs.
            depth: Current depth of recursion. Stop when depth >= max_depth.
        """
        if depth + 1 < self.max_depth:
            for link in self._extract_sub_links(text, url):
                if link not in visited:
                    for doc in self._lazy_load_recursive(link, visited, depth=depth + 1):
                        yield doc
                if link not in visited:
                    raise ValueError

    async def _async_get_child_links_recursive(
        self,
        url: str,
        visited: Set[str],
        session: aiohttp.ClientSession,
        *,
        depth: int = 0,
    ) -> AsyncIterator[Document]:
        if depth >= self.max_depth:
            return

        # Get all links that can be accessed from the current URL
        visited.add(url)
        try:
            response = requests.get(url, timeout=self.timeout, headers=self.headers)

            if self.encoding is not None:
                response.encoding = self.encoding
            elif self.autoset_encoding:
                response.encoding = response.apparent_encoding

            if self.check_response_status and 400 <= response.status_code <= 599:
                raise ValueError(f"Received HTTP status {response.status_code}")
        except Exception as e:
            async with session.get(url) as response:
                if self.check_response_status:
                    response.raise_for_status()
                text = await response.text()
        except (aiohttp.client_exceptions.InvalidURL, Exception) as e:
            if self.continue_on_failure:
                logger.warning(
                    f"Unable to load from {url}. Received error {e} of type "
                    f"Unable to load {url}. Received error {e} of type "
                    f"{e.__class__.__name__}"
                )
                return
            else:
                raise e
        content = self.extractor(response.text)
        if content:
            yield Document(
                page_content=content,
                metadata=self.metadata_extractor(response.text, url, response),
            )

        # Store the visited links and recursively visit the children
        sub_links = extract_sub_links(
            response.text,
        if content := self.extractor(text):
            metadata = self.metadata_extractor(text, url, response)
            yield Document(content, metadata=metadata)
        for link in self._extract_sub_links(text, url):
            if link not in visited:
                async for doc in self._async_get_child_links_recursive(
                    link, visited, session, depth=depth + 1
                ):
                    yield doc

    def lazy_load(self) -> Iterator[Document]:
        """Lazy load web pages.
        When use_async is True, this function will not be lazy,
        but it will still work in the expected way, just not lazy."""
        if self.use_async:

            async def aload():
                results = []
                async for doc in self.alazy_load():
                    results.append(doc)
                return results

            return iter(asyncio.run(aload()))

        else:
            yield from self._lazy_load_recursive(self.url, set())

    async def alazy_load(self) -> AsyncIterator[Document]:
        async with aiohttp.ClientSession(
            connector=aiohttp.TCPConnector(ssl=False),
            timeout=aiohttp.ClientTimeout(total=self.timeout),
            headers=self.headers,
        ) as session:
            async for doc in self._async_get_child_links_recursive(
                self.url, set(), session
            ):
                yield doc

    def _extract_sub_links(self, html: str, url: str) -> List[str]:
        return extract_sub_links(
            html,
            url,
            base_url=self.base_url,
            pattern=self.link_regex,
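The hunk above moves the loader to generator-based `lazy_load` and `alazy_load` paths. A small, self-contained consumption sketch, assuming this branch's `alazy_load` method (the constructor arguments are arbitrary):

```python
import asyncio

from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader

loader = RecursiveUrlLoader("https://docs.python.org/3.9/", max_depth=2)

# Synchronous path: lazy_load yields Documents one page at a time.
for doc in loader.lazy_load():
    print(doc.metadata["source"])


# Asynchronous path (alazy_load exists only on this branch): drive it with async for.
async def collect() -> list:
    return [doc async for doc in loader.alazy_load()]

docs = asyncio.run(collect())
```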
@@ -227,125 +279,31 @@ class RecursiveUrlLoader(BaseLoader):
            exclude_prefixes=self.exclude_dirs,
            continue_on_failure=self.continue_on_failure,
        )
        for link in sub_links:
            # Check all unvisited links
            if link not in visited:
                yield from self._get_child_links_recursive(
                    link, visited, depth=depth + 1
                )

    async def _async_get_child_links_recursive(
        self,
        url: str,
        visited: Set[str],
        *,
        session: Optional[aiohttp.ClientSession] = None,
        depth: int = 0,
    ) -> List[Document]:
        """Recursively get all child links starting with the path of the input URL.

        Args:
            url: The URL to crawl.
            visited: A set of visited URLs.
            depth: To reach the current url, how many pages have been visited.
        """
        if not self.use_async:
            raise ValueError(
                "Async functions forbidden when not initialized with `use_async`"
            )

    def _request(self, url: str) -> Optional[requests.Response]:
        try:
            import aiohttp
        except ImportError:
            raise ImportError(
                "The aiohttp package is required for the RecursiveUrlLoader. "
                "Please install it with `pip install aiohttp`."
            )
        if depth >= self.max_depth:
            return []
        response = requests.get(url, timeout=self.timeout, headers=self.headers)

        # Disable SSL verification because websites may have invalid SSL certificates,
        # but won't cause any security issues for us.
        close_session = session is None
        session = (
            session
            if session is not None
            else aiohttp.ClientSession(
                connector=aiohttp.TCPConnector(ssl=False),
                timeout=aiohttp.ClientTimeout(total=self.timeout),
                headers=self.headers,
            )
        )
        visited.add(url)
        try:
            async with session.get(url) as response:
                text = await response.text()
                if self.check_response_status and 400 <= response.status <= 599:
                    raise ValueError(f"Received HTTP status {response.status}")
        except (aiohttp.client_exceptions.InvalidURL, Exception) as e:
            if close_session:
                await session.close()
            if self.encoding is not None:
                response.encoding = self.encoding
            elif self.autoset_encoding:
                response.encoding = response.apparent_encoding
            else:
                pass

            if self.check_response_status:
                response.raise_for_status()
        except Exception as e:
            if self.continue_on_failure:
                logger.warning(
                    f"Unable to load {url}. Received error {e} of type "
                    f"{e.__class__.__name__}"
                )
                return []
                return None
            else:
                raise e
        results = []
        content = self.extractor(text)
        if content:
            results.append(
                Document(
                    page_content=content,
                    metadata=self.metadata_extractor(text, url, response),
                )
            )
        if depth < self.max_depth - 1:
            sub_links = extract_sub_links(
                text,
                url,
                base_url=self.base_url,
                pattern=self.link_regex,
                prevent_outside=self.prevent_outside,
                exclude_prefixes=self.exclude_dirs,
                continue_on_failure=self.continue_on_failure,
            )

            # Recursively call the function to get the children of the children
            sub_tasks = []
            to_visit = set(sub_links).difference(visited)
            for link in to_visit:
                sub_tasks.append(
                    self._async_get_child_links_recursive(
                        link, visited, session=session, depth=depth + 1
                    )
                )
            next_results = await asyncio.gather(*sub_tasks)
            for sub_result in next_results:
                if isinstance(sub_result, Exception) or sub_result is None:
                    # We don't want to stop the whole process, so just ignore it
                    # Not standard html format or invalid url or 404 may cause this.
                    continue
                # locking not fully working, temporary hack to ensure deduplication
                results += [r for r in sub_result if r not in results]
            if close_session:
                await session.close()
            return results

    def lazy_load(self) -> Iterator[Document]:
        """Lazy load web pages.
        When use_async is True, this function will not be lazy,
        but it will still work in the expected way, just not lazy."""
        visited: Set[str] = set()
        if self.use_async:
            results = asyncio.run(
                self._async_get_child_links_recursive(self.url, visited)
            )
            return iter(results or [])
        else:
            return self._get_child_links_recursive(self.url, visited)
        return response


_MetadataExtractorType1 = Callable[[str, str], dict]
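The hunk above ends at the `_MetadataExtractorType1 = Callable[[str, str], dict]` alias, which together with `_wrap_metadata_extractor` suggests the loader accepts a two-argument `(raw_html, url)` metadata extractor and adapts it internally; the wrapper's exact behavior is inferred, not shown in this diff. A hedged sketch of such an extractor:

```python
import re

from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader


# Hypothetical custom metadata extractor matching the Callable[[str, str], dict] alias.
def title_metadata(raw_html: str, url: str) -> dict:
    match = re.search(r"<title>(.*?)</title>", raw_html, re.IGNORECASE | re.DOTALL)
    return {"source": url, "title": match.group(1).strip() if match else ""}


loader = RecursiveUrlLoader(
    "https://docs.python.org/3.9/",
    metadata_extractor=title_metadata,
)
```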
@@ -1,7 +1,9 @@
from datetime import datetime

from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader


def test_async_recursive_url_loader() -> None:
async def test_async_recursive_url_loader() -> None:
    url = "https://docs.python.org/3.9/"
    loader = RecursiveUrlLoader(
        url,

@@ -11,7 +13,7 @@ def test_async_recursive_url_loader() -> None:
        timeout=None,
        check_response_status=True,
    )
    docs = loader.load()
    docs = [document async for document in loader.alazy_load()]
    assert len(docs) == 512
    assert docs[0].page_content == "placeholder"

@@ -30,13 +32,24 @@ def test_async_recursive_url_loader_deterministic() -> None:


def test_sync_recursive_url_loader() -> None:
    url = "https://docs.python.org/3.9/"
    url = "https://python.langchain.com/"
    loader = RecursiveUrlLoader(
        url, extractor=lambda _: "placeholder", use_async=False, max_depth=2
        url,
        extractor=lambda _: "placeholder",
        use_async=False,
        max_depth=3,
        timeout=None,
        check_response_status=True,
        retries=10,
    )
    docs = loader.load()
    assert len(docs) == 24
    docs = [document for document in loader.lazy_load()]
    with open(f"/Users/bagatur/Desktop/docs_{datetime.now()}.txt", "w") as f:
        f.write("\n".join(doc.metadata["source"] for doc in docs))
    assert docs[0].page_content == "placeholder"
    # no duplicates
    deduped = [doc for i, doc in enumerate(docs) if doc not in docs[:i]]
    assert len(docs) == len(deduped)
    assert len(docs) == 512


def test_sync_async_equivalent() -> None:
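The body of `test_sync_async_equivalent` is not shown in this diff. Purely as an illustration of what such a check could look like (not the PR's actual test), one might compare the sources yielded by the two paths:

```python
import asyncio

from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader


def check_sync_async_equivalent(url: str) -> None:
    # Illustrative only: build one sync and one async loader with the same settings
    # and compare the set of page sources each of them yields.
    sync_loader = RecursiveUrlLoader(url, use_async=False, max_depth=2)
    async_loader = RecursiveUrlLoader(url, use_async=True, max_depth=2)

    sync_docs = list(sync_loader.lazy_load())

    async def gather() -> list:
        return [doc async for doc in async_loader.alazy_load()]

    async_docs = asyncio.run(gather())
    assert sorted(d.metadata["source"] for d in sync_docs) == sorted(
        d.metadata["source"] for d in async_docs
    )
```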
@@ -11,19 +11,26 @@ import requests_mock

from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader

fake_url = f"https://{uuid.uuid4()}.com"
link_to_one_two = """
<div><a href="/one">link_to_one</a></div>
<div><a href="/two">link_to_two</a></div>
"""
link_to_two = '<div><a href="../two">link_to_two</a></div>'
link_to_three = '<div><a href="../three">link_to_three</a></div>'
no_links = "<p>no links<p>"

fake_url = f"https://{uuid.uuid4()}.com"
link_to_three_four_five = f"""
<div><a href="{fake_url}/three">link_to_three</a></div>
<div><a href="../four">link_to_four</a></div>
<div><a href="../five/foo">link_to_five</a></div>
"""
link_to_index = f'<div><a href="{fake_url}">link_to_index</a></div>'
URL_TO_HTML = {
    fake_url: link_to_one_two,
    f"{fake_url}/one": link_to_three,
    f"{fake_url}/two": link_to_three,
    f"{fake_url}/three": no_links,
    f"{fake_url}/two": link_to_three_four_five,
    f"{fake_url}/three": link_to_two,
    f"{fake_url}/four": link_to_index,
    f"{fake_url}/five/foo": link_to_three_four_five,
}


@@ -44,7 +51,7 @@ class MockGet:
        return self


@pytest.mark.parametrize(("max_depth", "expected_docs"), [(1, 1), (2, 3), (3, 4)])
@pytest.mark.parametrize(("max_depth", "expected_docs"), [(1, 1), (2, 3), (3, 6)])
@pytest.mark.parametrize("use_async", [False, True])
def test_lazy_load(
    mocker: Any, max_depth: int, expected_docs: int, use_async: bool
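The updated parametrization expects 1, 3, and 6 documents for `max_depth` 1, 2, and 3. Under the assumption that `max_depth=1` yields only the root page and each additional level follows the links in the updated `URL_TO_HTML` entries, a quick standalone breadth-first count reproduces those numbers (this is a sanity-check sketch, not part of the test file):

```python
# Standalone sanity check: breadth-first reachability over the link graph encoded
# by the updated URL_TO_HTML fixture above, keyed by path relative to fake_url.
GRAPH = {
    "index": ["one", "two"],
    "one": ["three"],
    "two": ["three", "four", "five/foo"],
    "three": ["two"],
    "four": ["index"],
    "five/foo": ["three", "four", "five/foo"],
}


def reachable(max_depth: int) -> int:
    visited, frontier = {"index"}, ["index"]
    for _ in range(max_depth - 1):
        frontier = [c for page in frontier for c in GRAPH[page] if c not in visited]
        visited.update(frontier)
    return len(visited)


assert [reachable(d) for d in (1, 2, 3)] == [1, 3, 6]
```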
@@ -1,3 +1,4 @@
import importlib.util
import logging
import re
from typing import List, Optional, Sequence, Union

@@ -43,6 +44,11 @@ def find_all_links(
    Returns:
        List[str]: all links
    """
    if importlib.util.find_spec("bs4"):
        from bs4 import BeautifulSoup

        soup = BeautifulSoup(raw_html)
        return [tag["href"] for tag in soup.findAll('a', href=True)]
    pattern = pattern or DEFAULT_LINK_REGEX
    return list(set(re.findall(pattern, raw_html)))

@@ -115,4 +121,4 @@ def extract_sub_links(
            continue

        results.append(path)
    return results
    return sorted(results)
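The `find_all_links` hunk adds a BeautifulSoup path ahead of the existing regex fallback. A small, self-contained illustration of the two extraction approaches; the regex here is a rough approximation of the idea, not `DEFAULT_LINK_REGEX` itself:

```python
import re

from bs4 import BeautifulSoup  # only needed for the bs4 path

raw_html = '<div><a href="/one">one</a><a href="/two">two</a></div>'

# bs4 path: parse the HTML and collect href attributes from anchor tags.
soup = BeautifulSoup(raw_html, "html.parser")
print([tag["href"] for tag in soup.find_all("a", href=True)])  # ['/one', '/two']

# Regex fallback path (approximation of the default-pattern idea used upstream).
print(sorted(set(re.findall(r'href=["\'](.*?)["\']', raw_html))))  # ['/one', '/two']
```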