diff --git a/libs/community/langchain_community/document_loaders/recursive_url_loader.py b/libs/community/langchain_community/document_loaders/recursive_url_loader.py
index 6231c7af8d6..687b02ccd69 100644
--- a/libs/community/langchain_community/document_loaders/recursive_url_loader.py
+++ b/libs/community/langchain_community/document_loaders/recursive_url_loader.py
@@ -1,10 +1,10 @@
 from __future__ import annotations

 import asyncio
+import inspect
 import logging
 import re
 from typing import (
-    TYPE_CHECKING,
     Callable,
     Iterator,
     List,
@@ -12,23 +12,25 @@ from typing import (
     Sequence,
     Set,
     Union,
+    cast,
 )

+import aiohttp
 import requests
 from langchain_core.documents import Document
 from langchain_core.utils.html import extract_sub_links

 from langchain_community.document_loaders.base import BaseLoader

-if TYPE_CHECKING:
-    import aiohttp
-
 logger = logging.getLogger(__name__)


-def _metadata_extractor(raw_html: str, url: str) -> dict:
+def _metadata_extractor(
+    raw_html: str, url: str, response: Union[requests.Response, aiohttp.ClientResponse]
+) -> dict:
     """Extract metadata from raw html using BeautifulSoup."""
-    metadata = {"source": url}
+    content_type = response.headers.get("Content-Type", "")
+    metadata = {"source": url, "content_type": content_type}

     try:
         from bs4 import BeautifulSoup
@@ -86,7 +88,7 @@ class RecursiveUrlLoader(BaseLoader):
         max_depth: Optional[int] = 2,
         use_async: Optional[bool] = None,
         extractor: Optional[Callable[[str], str]] = None,
-        metadata_extractor: Optional[Callable[[str, str], dict]] = None,
+        metadata_extractor: Optional[_MetadataExtractorType] = None,
         exclude_dirs: Optional[Sequence[str]] = (),
         timeout: Optional[int] = 10,
         prevent_outside: bool = True,
@@ -108,10 +110,25 @@ class RecursiveUrlLoader(BaseLoader):
             extractor: A function to extract document contents from raw html.
                 When extract function returns an empty string, the document is
                 ignored.
-            metadata_extractor: A function to extract metadata from raw html and the
-                source url (args in that order). Default extractor will attempt
-                to use BeautifulSoup4 to extract the title, description and language
-                of the page.
+            metadata_extractor: A function to extract metadata from the raw html, the
+                source url, and the requests.Response/aiohttp.ClientResponse object
+                (args in that order).
+                The default extractor will attempt to use BeautifulSoup4 to extract
+                the title, description and language of the page.
+
+                .. code-block:: python
+
+                    from typing import Union
+
+                    import aiohttp
+                    import requests
+
+                    def simple_metadata_extractor(
+                        raw_html: str, url: str, response: Union[requests.Response, aiohttp.ClientResponse]
+                    ) -> dict:
+                        content_type = response.headers.get("Content-Type", "")
+                        return {"source": url, "content_type": content_type}
+
             exclude_dirs: A list of subdirectories to exclude.
             timeout: The timeout for the requests, in the unit of seconds. If None then
                 connection will not timeout.
@@ -123,16 +140,17 @@ class RecursiveUrlLoader(BaseLoader):
             continue_on_failure: If True, continue if getting or parsing a link
                 raises an exception. Otherwise, raise the exception.
             base_url: The base url to check for outside links against.
- """ + """ # noqa: E501 self.url = url self.max_depth = max_depth if max_depth is not None else 2 self.use_async = use_async if use_async is not None else False self.extractor = extractor if extractor is not None else lambda x: x - self.metadata_extractor = ( + metadata_extractor = ( metadata_extractor if metadata_extractor is not None else _metadata_extractor ) + self.metadata_extractor = _wrap_metadata_extractor(metadata_extractor) self.exclude_dirs = exclude_dirs if exclude_dirs is not None else () if any(url.startswith(exclude_dir) for exclude_dir in self.exclude_dirs): @@ -184,7 +199,7 @@ class RecursiveUrlLoader(BaseLoader): if content: yield Document( page_content=content, - metadata=self.metadata_extractor(response.text, url), + metadata=self.metadata_extractor(response.text, url, response), ) # Store the visited links and recursively visit the children @@ -270,7 +285,7 @@ class RecursiveUrlLoader(BaseLoader): results.append( Document( page_content=content, - metadata=self.metadata_extractor(text, url), + metadata=self.metadata_extractor(text, url, response), ) ) if depth < self.max_depth - 1: @@ -318,3 +333,27 @@ class RecursiveUrlLoader(BaseLoader): return iter(results or []) else: return self._get_child_links_recursive(self.url, visited) + + +_MetadataExtractorType1 = Callable[[str, str], dict] +_MetadataExtractorType2 = Callable[ + [str, str, Union[requests.Response, aiohttp.ClientResponse]], dict +] +_MetadataExtractorType = Union[_MetadataExtractorType1, _MetadataExtractorType2] + + +def _wrap_metadata_extractor( + metadata_extractor: _MetadataExtractorType, +) -> _MetadataExtractorType2: + if len(inspect.signature(metadata_extractor).parameters) == 3: + return cast(_MetadataExtractorType2, metadata_extractor) + else: + + def _metadata_extractor_wrapper( + raw_html: str, + url: str, + response: Union[requests.Response, aiohttp.ClientResponse], + ) -> dict: + return cast(_MetadataExtractorType1, metadata_extractor)(raw_html, url) + + return _metadata_extractor_wrapper diff --git a/libs/community/tests/integration_tests/document_loaders/test_recursive_url_loader.py b/libs/community/tests/integration_tests/document_loaders/test_recursive_url_loader.py index ff8083bcd50..1f359932227 100644 --- a/libs/community/tests/integration_tests/document_loaders/test_recursive_url_loader.py +++ b/libs/community/tests/integration_tests/document_loaders/test_recursive_url_loader.py @@ -35,7 +35,7 @@ def test_sync_recursive_url_loader() -> None: url, extractor=lambda _: "placeholder", use_async=False, max_depth=2 ) docs = loader.load() - assert len(docs) == 25 + assert len(docs) == 24 assert docs[0].page_content == "placeholder" @@ -55,3 +55,17 @@ def test_loading_invalid_url() -> None: ) docs = loader.load() assert len(docs) == 0 + + +def test_sync_async_metadata_necessary_properties() -> None: + url = "https://docs.python.org/3.9/" + loader = RecursiveUrlLoader(url, use_async=False, max_depth=2) + async_loader = RecursiveUrlLoader(url, use_async=False, max_depth=2) + docs = loader.load() + async_docs = async_loader.load() + for doc in docs: + assert "source" in doc.metadata + assert "content_type" in doc.metadata + for doc in async_docs: + assert "source" in doc.metadata + assert "content_type" in doc.metadata