From bb284eebe4d3aab6b1884c4ddfb31861456810bd Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Wed, 6 Mar 2024 19:15:24 +0100 Subject: [PATCH] Merge pull request #18436 * Implement lazy_load() for ConfluenceLoader --- .../document_loaders/confluence.py | 194 +++++++++++------- .../document_loaders/test_confluence.py | 27 ++- 2 files changed, 132 insertions(+), 89 deletions(-) diff --git a/libs/community/langchain_community/document_loaders/confluence.py b/libs/community/langchain_community/document_loaders/confluence.py index ca7060a286d..e6445fb2795 100644 --- a/libs/community/langchain_community/document_loaders/confluence.py +++ b/libs/community/langchain_community/document_loaders/confluence.py @@ -1,7 +1,7 @@ import logging from enum import Enum from io import BytesIO -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, Iterator, List, Optional, Union import requests from langchain_core.documents import Document @@ -49,7 +49,7 @@ class ConfluenceLoader(BaseLoader): Confluence API supports difference format of page content. The storage format is the raw XML representation for storage. The view format is the HTML representation for viewing with macros are rendered as though it is viewed by users. You can pass - a enum `content_format` argument to `load()` to specify the content format, this is + a enum `content_format` argument to specify the content format, this is set to `ContentFormat.STORAGE` by default, the supported values are: `ContentFormat.EDITOR`, `ContentFormat.EXPORT_VIEW`, `ContentFormat.ANONYMOUS_EXPORT_VIEW`, `ContentFormat.STORAGE`, @@ -66,18 +66,22 @@ class ConfluenceLoader(BaseLoader): loader = ConfluenceLoader( url="https://yoursite.atlassian.com/wiki", username="me", - api_key="12345" + api_key="12345", + space_key="SPACE", + limit=50, ) - documents = loader.load(space_key="SPACE",limit=50) + documents = loader.load() # Server on perm loader = ConfluenceLoader( url="https://confluence.yoursite.com/", username="me", api_key="your_password", - cloud=False + cloud=False, + space_key="SPACE", + limit=50, ) - documents = loader.load(space_key="SPACE",limit=50) + documents = loader.load() :param url: _description_ :type url: str @@ -99,6 +103,43 @@ class ConfluenceLoader(BaseLoader): :type max_retry_seconds: Optional[int], optional :param confluence_kwargs: additional kwargs to initialize confluence with :type confluence_kwargs: dict, optional + :param space_key: Space key retrieved from a confluence URL, defaults to None + :type space_key: Optional[str], optional + :param page_ids: List of specific page IDs to load, defaults to None + :type page_ids: Optional[List[str]], optional + :param label: Get all pages with this label, defaults to None + :type label: Optional[str], optional + :param cql: CQL Expression, defaults to None + :type cql: Optional[str], optional + :param include_restricted_content: defaults to False + :type include_restricted_content: bool, optional + :param include_archived_content: Whether to include archived content, + defaults to False + :type include_archived_content: bool, optional + :param include_attachments: defaults to False + :type include_attachments: bool, optional + :param include_comments: defaults to False + :type include_comments: bool, optional + :param content_format: Specify content format, defaults to + ContentFormat.STORAGE, the supported values are: + `ContentFormat.EDITOR`, `ContentFormat.EXPORT_VIEW`, + `ContentFormat.ANONYMOUS_EXPORT_VIEW`, + `ContentFormat.STORAGE`, and `ContentFormat.VIEW`. + :type content_format: ContentFormat + :param limit: Maximum number of pages to retrieve per request, defaults to 50 + :type limit: int, optional + :param max_pages: Maximum number of pages to retrieve in total, defaults 1000 + :type max_pages: int, optional + :param ocr_languages: The languages to use for the Tesseract agent. To use a + language, you'll first need to install the appropriate + Tesseract language pack. + :type ocr_languages: str, optional + :param keep_markdown_format: Whether to keep the markdown format, defaults to + False + :type keep_markdown_format: bool + :param keep_newlines: Whether to keep the newlines format, defaults to + False + :type keep_newlines: bool :raises ValueError: Errors while validating input :raises ImportError: Required dependencies not installed. """ @@ -116,7 +157,37 @@ class ConfluenceLoader(BaseLoader): min_retry_seconds: Optional[int] = 2, max_retry_seconds: Optional[int] = 10, confluence_kwargs: Optional[dict] = None, + *, + space_key: Optional[str] = None, + page_ids: Optional[List[str]] = None, + label: Optional[str] = None, + cql: Optional[str] = None, + include_restricted_content: bool = False, + include_archived_content: bool = False, + include_attachments: bool = False, + include_comments: bool = False, + content_format: ContentFormat = ContentFormat.STORAGE, + limit: Optional[int] = 50, + max_pages: Optional[int] = 1000, + ocr_languages: Optional[str] = None, + keep_markdown_format: bool = False, + keep_newlines: bool = False, ): + self.space_key = space_key + self.page_ids = page_ids + self.label = label + self.cql = cql + self.include_restricted_content = include_restricted_content + self.include_archived_content = include_archived_content + self.include_attachments = include_attachments + self.include_comments = include_comments + self.content_format = content_format + self.limit = limit + self.max_pages = max_pages + self.ocr_languages = ocr_languages + self.keep_markdown_format = keep_markdown_format + self.keep_newlines = keep_newlines + confluence_kwargs = confluence_kwargs or {} errors = ConfluenceLoader.validate_init_args( url=url, @@ -204,74 +275,40 @@ class ConfluenceLoader(BaseLoader): ) return errors or None - def load( - self, - space_key: Optional[str] = None, - page_ids: Optional[List[str]] = None, - label: Optional[str] = None, - cql: Optional[str] = None, - include_restricted_content: bool = False, - include_archived_content: bool = False, - include_attachments: bool = False, - include_comments: bool = False, - content_format: ContentFormat = ContentFormat.STORAGE, - limit: Optional[int] = 50, - max_pages: Optional[int] = 1000, - ocr_languages: Optional[str] = None, - keep_markdown_format: bool = False, - keep_newlines: bool = False, - ) -> List[Document]: - """ - :param space_key: Space key retrieved from a confluence URL, defaults to None - :type space_key: Optional[str], optional - :param page_ids: List of specific page IDs to load, defaults to None - :type page_ids: Optional[List[str]], optional - :param label: Get all pages with this label, defaults to None - :type label: Optional[str], optional - :param cql: CQL Expression, defaults to None - :type cql: Optional[str], optional - :param include_restricted_content: defaults to False - :type include_restricted_content: bool, optional - :param include_archived_content: Whether to include archived content, - defaults to False - :type include_archived_content: bool, optional - :param include_attachments: defaults to False - :type include_attachments: bool, optional - :param include_comments: defaults to False - :type include_comments: bool, optional - :param content_format: Specify content format, defaults to - ContentFormat.STORAGE, the supported values are: - `ContentFormat.EDITOR`, `ContentFormat.EXPORT_VIEW`, - `ContentFormat.ANONYMOUS_EXPORT_VIEW`, - `ContentFormat.STORAGE`, and `ContentFormat.VIEW`. - :type content_format: ContentFormat - :param limit: Maximum number of pages to retrieve per request, defaults to 50 - :type limit: int, optional - :param max_pages: Maximum number of pages to retrieve in total, defaults 1000 - :type max_pages: int, optional - :param ocr_languages: The languages to use for the Tesseract agent. To use a - language, you'll first need to install the appropriate - Tesseract language pack. - :type ocr_languages: str, optional - :param keep_markdown_format: Whether to keep the markdown format, defaults to - False - :type keep_markdown_format: bool - :param keep_newlines: Whether to keep the newlines format, defaults to - False - :type keep_newlines: bool - :raises ValueError: _description_ - :raises ImportError: _description_ - :return: _description_ - :rtype: List[Document] - """ + def _resolve_param(self, param_name: str, kwargs: Any) -> Any: + return kwargs[param_name] if param_name in kwargs else getattr(self, param_name) + + def _lazy_load(self, **kwargs: Any) -> Iterator[Document]: + if kwargs: + logger.warning( + f"Received runtime arguments {kwargs}. Passing runtime args to `load`" + f" is deprecated. Please pass arguments during initialization instead." + ) + space_key = self._resolve_param("space_key", kwargs) + page_ids = self._resolve_param("page_ids", kwargs) + label = self._resolve_param("label", kwargs) + cql = self._resolve_param("cql", kwargs) + include_restricted_content = self._resolve_param( + "include_restricted_content", kwargs + ) + include_archived_content = self._resolve_param( + "include_archived_content", kwargs + ) + include_attachments = self._resolve_param("include_attachments", kwargs) + include_comments = self._resolve_param("include_comments", kwargs) + content_format = self._resolve_param("content_format", kwargs) + limit = self._resolve_param("limit", kwargs) + max_pages = self._resolve_param("max_pages", kwargs) + ocr_languages = self._resolve_param("ocr_languages", kwargs) + keep_markdown_format = self._resolve_param("keep_markdown_format", kwargs) + keep_newlines = self._resolve_param("keep_newlines", kwargs) + if not space_key and not page_ids and not label and not cql: raise ValueError( "Must specify at least one among `space_key`, `page_ids`, " "`label`, `cql` parameters." ) - docs = [] - if space_key: pages = self.paginate_request( self.confluence.get_all_pages_from_space, @@ -281,7 +318,7 @@ class ConfluenceLoader(BaseLoader): status="any" if include_archived_content else "current", expand=content_format.value, ) - docs += self.process_pages( + yield from self.process_pages( pages, include_restricted_content, include_attachments, @@ -314,7 +351,7 @@ class ConfluenceLoader(BaseLoader): include_archived_spaces=include_archived_content, expand=content_format.value, ) - docs += self.process_pages( + yield from self.process_pages( pages, include_restricted_content, include_attachments, @@ -343,7 +380,7 @@ class ConfluenceLoader(BaseLoader): ) if not include_restricted_content and not self.is_public_page(page): continue - doc = self.process_page( + yield self.process_page( page, include_attachments, include_comments, @@ -351,9 +388,12 @@ class ConfluenceLoader(BaseLoader): ocr_languages, keep_markdown_format, ) - docs.append(doc) - return docs + def load(self, **kwargs: Any) -> List[Document]: + return list(self._lazy_load(**kwargs)) + + def lazy_load(self) -> Iterator[Document]: + yield from self._lazy_load() def _search_content_by_cql( self, cql: str, include_archived_spaces: Optional[bool] = None, **kwargs: Any @@ -430,13 +470,12 @@ class ConfluenceLoader(BaseLoader): ocr_languages: Optional[str] = None, keep_markdown_format: Optional[bool] = False, keep_newlines: bool = False, - ) -> List[Document]: + ) -> Iterator[Document]: """Process a list of pages into a list of documents.""" - docs = [] for page in pages: if not include_restricted_content and not self.is_public_page(page): continue - doc = self.process_page( + yield self.process_page( page, include_attachments, include_comments, @@ -445,9 +484,6 @@ class ConfluenceLoader(BaseLoader): keep_markdown_format=keep_markdown_format, keep_newlines=keep_newlines, ) - docs.append(doc) - - return docs def process_page( self, diff --git a/libs/community/tests/unit_tests/document_loaders/test_confluence.py b/libs/community/tests/unit_tests/document_loaders/test_confluence.py index 38c474a3e9b..e6819a48a88 100644 --- a/libs/community/tests/unit_tests/document_loaders/test_confluence.py +++ b/libs/community/tests/unit_tests/document_loaders/test_confluence.py @@ -1,5 +1,5 @@ import unittest -from typing import Dict +from typing import Any, Dict from unittest.mock import MagicMock, patch import pytest @@ -108,10 +108,12 @@ class TestConfluenceLoader: self._get_mock_page_restrictions("456"), ] - confluence_loader = self._get_mock_confluence_loader(mock_confluence) - mock_page_ids = ["123", "456"] - documents = confluence_loader.load(page_ids=mock_page_ids) + confluence_loader = self._get_mock_confluence_loader( + mock_confluence, page_ids=mock_page_ids + ) + + documents = list(confluence_loader.lazy_load()) assert mock_confluence.get_page_by_id.call_count == 2 assert mock_confluence.get_all_restrictions_for_content.call_count == 2 @@ -139,9 +141,11 @@ class TestConfluenceLoader: self._get_mock_page_restrictions("456"), ] - confluence_loader = self._get_mock_confluence_loader(mock_confluence) + confluence_loader = self._get_mock_confluence_loader( + mock_confluence, space_key=self.MOCK_SPACE_KEY, max_pages=2 + ) - documents = confluence_loader.load(space_key=self.MOCK_SPACE_KEY, max_pages=2) + documents = confluence_loader.load() assert mock_confluence.get_all_pages_from_space.call_count == 1 @@ -155,6 +159,7 @@ class TestConfluenceLoader: assert mock_confluence.cql.call_count == 0 assert mock_confluence.get_page_child_by_type.call_count == 0 + @pytest.mark.requires("markdownify") def test_confluence_loader_when_content_format_and_keep_markdown_format_enabled( self, mock_confluence: MagicMock ) -> None: @@ -168,15 +173,16 @@ class TestConfluenceLoader: self._get_mock_page_restrictions("456"), ] - confluence_loader = self._get_mock_confluence_loader(mock_confluence) - - documents = confluence_loader.load( + confluence_loader = self._get_mock_confluence_loader( + mock_confluence, space_key=self.MOCK_SPACE_KEY, content_format=ContentFormat.VIEW, keep_markdown_format=True, max_pages=2, ) + documents = confluence_loader.load() + assert mock_confluence.get_all_pages_from_space.call_count == 1 assert len(documents) == 2 @@ -190,12 +196,13 @@ class TestConfluenceLoader: assert mock_confluence.get_page_child_by_type.call_count == 0 def _get_mock_confluence_loader( - self, mock_confluence: MagicMock + self, mock_confluence: MagicMock, **kwargs: Any ) -> ConfluenceLoader: confluence_loader = ConfluenceLoader( self.CONFLUENCE_URL, username=self.MOCK_USERNAME, api_key=self.MOCK_API_TOKEN, + **kwargs, ) confluence_loader.confluence = mock_confluence return confluence_loader