Merge pull request #18436

* Implement lazy_load() for ConfluenceLoader
This commit is contained in:
Christophe Bornet 2024-03-06 19:15:24 +01:00 committed by GitHub
parent 691480f491
commit bb284eebe4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 132 additions and 89 deletions

View File

@ -1,7 +1,7 @@
import logging
from enum import Enum
from io import BytesIO
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, Iterator, List, Optional, Union
import requests
from langchain_core.documents import Document
@ -49,7 +49,7 @@ class ConfluenceLoader(BaseLoader):
Confluence API supports different formats of page content. The storage format is the
raw XML representation for storage. The view format is the HTML representation for
viewing, with macros rendered as though the page is viewed by users. You can pass
a enum `content_format` argument to `load()` to specify the content format, this is
an enum `content_format` argument to specify the content format; this is
set to `ContentFormat.STORAGE` by default, the supported values are:
`ContentFormat.EDITOR`, `ContentFormat.EXPORT_VIEW`,
`ContentFormat.ANONYMOUS_EXPORT_VIEW`, `ContentFormat.STORAGE`,
@ -66,18 +66,22 @@ class ConfluenceLoader(BaseLoader):
loader = ConfluenceLoader(
url="https://yoursite.atlassian.com/wiki",
username="me",
api_key="12345"
api_key="12345",
space_key="SPACE",
limit=50,
)
documents = loader.load(space_key="SPACE",limit=50)
documents = loader.load()
# Server on prem
loader = ConfluenceLoader(
url="https://confluence.yoursite.com/",
username="me",
api_key="your_password",
cloud=False
cloud=False,
space_key="SPACE",
limit=50,
)
documents = loader.load(space_key="SPACE",limit=50)
documents = loader.load()
:param url: _description_
:type url: str
@ -99,6 +103,43 @@ class ConfluenceLoader(BaseLoader):
:type max_retry_seconds: Optional[int], optional
:param confluence_kwargs: additional kwargs to initialize confluence with
:type confluence_kwargs: dict, optional
:param space_key: Space key retrieved from a confluence URL, defaults to None
:type space_key: Optional[str], optional
:param page_ids: List of specific page IDs to load, defaults to None
:type page_ids: Optional[List[str]], optional
:param label: Get all pages with this label, defaults to None
:type label: Optional[str], optional
:param cql: CQL Expression, defaults to None
:type cql: Optional[str], optional
:param include_restricted_content: defaults to False
:type include_restricted_content: bool, optional
:param include_archived_content: Whether to include archived content,
defaults to False
:type include_archived_content: bool, optional
:param include_attachments: defaults to False
:type include_attachments: bool, optional
:param include_comments: defaults to False
:type include_comments: bool, optional
:param content_format: Specify content format, defaults to
ContentFormat.STORAGE, the supported values are:
`ContentFormat.EDITOR`, `ContentFormat.EXPORT_VIEW`,
`ContentFormat.ANONYMOUS_EXPORT_VIEW`,
`ContentFormat.STORAGE`, and `ContentFormat.VIEW`.
:type content_format: ContentFormat
:param limit: Maximum number of pages to retrieve per request, defaults to 50
:type limit: int, optional
:param max_pages: Maximum number of pages to retrieve in total, defaults to 1000
:type max_pages: int, optional
:param ocr_languages: The languages to use for the Tesseract agent. To use a
language, you'll first need to install the appropriate
Tesseract language pack.
:type ocr_languages: str, optional
:param keep_markdown_format: Whether to keep the markdown format, defaults to
False
:type keep_markdown_format: bool
:param keep_newlines: Whether to keep the newlines format, defaults to
False
:type keep_newlines: bool
:raises ValueError: Errors while validating input
:raises ImportError: Required dependencies not installed.
"""
@ -116,7 +157,37 @@ class ConfluenceLoader(BaseLoader):
min_retry_seconds: Optional[int] = 2,
max_retry_seconds: Optional[int] = 10,
confluence_kwargs: Optional[dict] = None,
*,
space_key: Optional[str] = None,
page_ids: Optional[List[str]] = None,
label: Optional[str] = None,
cql: Optional[str] = None,
include_restricted_content: bool = False,
include_archived_content: bool = False,
include_attachments: bool = False,
include_comments: bool = False,
content_format: ContentFormat = ContentFormat.STORAGE,
limit: Optional[int] = 50,
max_pages: Optional[int] = 1000,
ocr_languages: Optional[str] = None,
keep_markdown_format: bool = False,
keep_newlines: bool = False,
):
self.space_key = space_key
self.page_ids = page_ids
self.label = label
self.cql = cql
self.include_restricted_content = include_restricted_content
self.include_archived_content = include_archived_content
self.include_attachments = include_attachments
self.include_comments = include_comments
self.content_format = content_format
self.limit = limit
self.max_pages = max_pages
self.ocr_languages = ocr_languages
self.keep_markdown_format = keep_markdown_format
self.keep_newlines = keep_newlines
confluence_kwargs = confluence_kwargs or {}
errors = ConfluenceLoader.validate_init_args(
url=url,
@ -204,74 +275,40 @@ class ConfluenceLoader(BaseLoader):
)
return errors or None
def load(
self,
space_key: Optional[str] = None,
page_ids: Optional[List[str]] = None,
label: Optional[str] = None,
cql: Optional[str] = None,
include_restricted_content: bool = False,
include_archived_content: bool = False,
include_attachments: bool = False,
include_comments: bool = False,
content_format: ContentFormat = ContentFormat.STORAGE,
limit: Optional[int] = 50,
max_pages: Optional[int] = 1000,
ocr_languages: Optional[str] = None,
keep_markdown_format: bool = False,
keep_newlines: bool = False,
) -> List[Document]:
"""
:param space_key: Space key retrieved from a confluence URL, defaults to None
:type space_key: Optional[str], optional
:param page_ids: List of specific page IDs to load, defaults to None
:type page_ids: Optional[List[str]], optional
:param label: Get all pages with this label, defaults to None
:type label: Optional[str], optional
:param cql: CQL Expression, defaults to None
:type cql: Optional[str], optional
:param include_restricted_content: defaults to False
:type include_restricted_content: bool, optional
:param include_archived_content: Whether to include archived content,
defaults to False
:type include_archived_content: bool, optional
:param include_attachments: defaults to False
:type include_attachments: bool, optional
:param include_comments: defaults to False
:type include_comments: bool, optional
:param content_format: Specify content format, defaults to
ContentFormat.STORAGE, the supported values are:
`ContentFormat.EDITOR`, `ContentFormat.EXPORT_VIEW`,
`ContentFormat.ANONYMOUS_EXPORT_VIEW`,
`ContentFormat.STORAGE`, and `ContentFormat.VIEW`.
:type content_format: ContentFormat
:param limit: Maximum number of pages to retrieve per request, defaults to 50
:type limit: int, optional
:param max_pages: Maximum number of pages to retrieve in total, defaults 1000
:type max_pages: int, optional
:param ocr_languages: The languages to use for the Tesseract agent. To use a
language, you'll first need to install the appropriate
Tesseract language pack.
:type ocr_languages: str, optional
:param keep_markdown_format: Whether to keep the markdown format, defaults to
False
:type keep_markdown_format: bool
:param keep_newlines: Whether to keep the newlines format, defaults to
False
:type keep_newlines: bool
:raises ValueError: _description_
:raises ImportError: _description_
:return: _description_
:rtype: List[Document]
"""
def _resolve_param(self, param_name: str, kwargs: Any) -> Any:
return kwargs[param_name] if param_name in kwargs else getattr(self, param_name)
def _lazy_load(self, **kwargs: Any) -> Iterator[Document]:
if kwargs:
logger.warning(
f"Received runtime arguments {kwargs}. Passing runtime args to `load`"
f" is deprecated. Please pass arguments during initialization instead."
)
space_key = self._resolve_param("space_key", kwargs)
page_ids = self._resolve_param("page_ids", kwargs)
label = self._resolve_param("label", kwargs)
cql = self._resolve_param("cql", kwargs)
include_restricted_content = self._resolve_param(
"include_restricted_content", kwargs
)
include_archived_content = self._resolve_param(
"include_archived_content", kwargs
)
include_attachments = self._resolve_param("include_attachments", kwargs)
include_comments = self._resolve_param("include_comments", kwargs)
content_format = self._resolve_param("content_format", kwargs)
limit = self._resolve_param("limit", kwargs)
max_pages = self._resolve_param("max_pages", kwargs)
ocr_languages = self._resolve_param("ocr_languages", kwargs)
keep_markdown_format = self._resolve_param("keep_markdown_format", kwargs)
keep_newlines = self._resolve_param("keep_newlines", kwargs)
if not space_key and not page_ids and not label and not cql:
raise ValueError(
"Must specify at least one among `space_key`, `page_ids`, "
"`label`, `cql` parameters."
)
docs = []
if space_key:
pages = self.paginate_request(
self.confluence.get_all_pages_from_space,
@ -281,7 +318,7 @@ class ConfluenceLoader(BaseLoader):
status="any" if include_archived_content else "current",
expand=content_format.value,
)
docs += self.process_pages(
yield from self.process_pages(
pages,
include_restricted_content,
include_attachments,
@ -314,7 +351,7 @@ class ConfluenceLoader(BaseLoader):
include_archived_spaces=include_archived_content,
expand=content_format.value,
)
docs += self.process_pages(
yield from self.process_pages(
pages,
include_restricted_content,
include_attachments,
@ -343,7 +380,7 @@ class ConfluenceLoader(BaseLoader):
)
if not include_restricted_content and not self.is_public_page(page):
continue
doc = self.process_page(
yield self.process_page(
page,
include_attachments,
include_comments,
@ -351,9 +388,12 @@ class ConfluenceLoader(BaseLoader):
ocr_languages,
keep_markdown_format,
)
docs.append(doc)
return docs
def load(self, **kwargs: Any) -> List[Document]:
    """Load every document eagerly and return them as a list.

    All keyword arguments are forwarded unchanged to ``_lazy_load``, and
    the resulting iterator is fully consumed before returning.

    :param kwargs: Keyword arguments passed through to ``_lazy_load``.
    :type kwargs: Any
    :return: The documents produced by ``_lazy_load``, in iteration order.
    :rtype: List[Document]
    """
    documents = list(self._lazy_load(**kwargs))
    return documents
def lazy_load(self) -> Iterator[Document]:
    """Yield documents one at a time instead of building a list.

    Delegates to ``_lazy_load`` with no keyword arguments and yields
    whatever that iterator produces.

    :return: An iterator over the documents produced by ``_lazy_load``.
    :rtype: Iterator[Document]
    """
    documents = self._lazy_load()
    yield from documents
def _search_content_by_cql(
self, cql: str, include_archived_spaces: Optional[bool] = None, **kwargs: Any
@ -430,13 +470,12 @@ class ConfluenceLoader(BaseLoader):
ocr_languages: Optional[str] = None,
keep_markdown_format: Optional[bool] = False,
keep_newlines: bool = False,
) -> List[Document]:
) -> Iterator[Document]:
"""Process a list of pages into a list of documents."""
docs = []
for page in pages:
if not include_restricted_content and not self.is_public_page(page):
continue
doc = self.process_page(
yield self.process_page(
page,
include_attachments,
include_comments,
@ -445,9 +484,6 @@ class ConfluenceLoader(BaseLoader):
keep_markdown_format=keep_markdown_format,
keep_newlines=keep_newlines,
)
docs.append(doc)
return docs
def process_page(
self,

View File

@ -1,5 +1,5 @@
import unittest
from typing import Dict
from typing import Any, Dict
from unittest.mock import MagicMock, patch
import pytest
@ -108,10 +108,12 @@ class TestConfluenceLoader:
self._get_mock_page_restrictions("456"),
]
confluence_loader = self._get_mock_confluence_loader(mock_confluence)
mock_page_ids = ["123", "456"]
documents = confluence_loader.load(page_ids=mock_page_ids)
confluence_loader = self._get_mock_confluence_loader(
mock_confluence, page_ids=mock_page_ids
)
documents = list(confluence_loader.lazy_load())
assert mock_confluence.get_page_by_id.call_count == 2
assert mock_confluence.get_all_restrictions_for_content.call_count == 2
@ -139,9 +141,11 @@ class TestConfluenceLoader:
self._get_mock_page_restrictions("456"),
]
confluence_loader = self._get_mock_confluence_loader(mock_confluence)
confluence_loader = self._get_mock_confluence_loader(
mock_confluence, space_key=self.MOCK_SPACE_KEY, max_pages=2
)
documents = confluence_loader.load(space_key=self.MOCK_SPACE_KEY, max_pages=2)
documents = confluence_loader.load()
assert mock_confluence.get_all_pages_from_space.call_count == 1
@ -155,6 +159,7 @@ class TestConfluenceLoader:
assert mock_confluence.cql.call_count == 0
assert mock_confluence.get_page_child_by_type.call_count == 0
@pytest.mark.requires("markdownify")
def test_confluence_loader_when_content_format_and_keep_markdown_format_enabled(
self, mock_confluence: MagicMock
) -> None:
@ -168,15 +173,16 @@ class TestConfluenceLoader:
self._get_mock_page_restrictions("456"),
]
confluence_loader = self._get_mock_confluence_loader(mock_confluence)
documents = confluence_loader.load(
confluence_loader = self._get_mock_confluence_loader(
mock_confluence,
space_key=self.MOCK_SPACE_KEY,
content_format=ContentFormat.VIEW,
keep_markdown_format=True,
max_pages=2,
)
documents = confluence_loader.load()
assert mock_confluence.get_all_pages_from_space.call_count == 1
assert len(documents) == 2
@ -190,12 +196,13 @@ class TestConfluenceLoader:
assert mock_confluence.get_page_child_by_type.call_count == 0
def _get_mock_confluence_loader(
self, mock_confluence: MagicMock
self, mock_confluence: MagicMock, **kwargs: Any
) -> ConfluenceLoader:
confluence_loader = ConfluenceLoader(
self.CONFLUENCE_URL,
username=self.MOCK_USERNAME,
api_key=self.MOCK_API_TOKEN,
**kwargs,
)
confluence_loader.confluence = mock_confluence
return confluence_loader