Merge pull request #18436

* Implement lazy_load() for ConfluenceLoader
This commit is contained in:
Christophe Bornet 2024-03-06 19:15:24 +01:00 committed by GitHub
parent 691480f491
commit bb284eebe4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 132 additions and 89 deletions

View File

@ -1,7 +1,7 @@
import logging import logging
from enum import Enum from enum import Enum
from io import BytesIO from io import BytesIO
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, Iterator, List, Optional, Union
import requests import requests
from langchain_core.documents import Document from langchain_core.documents import Document
@ -49,7 +49,7 @@ class ConfluenceLoader(BaseLoader):
Confluence API supports different formats of page content. The storage format is the Confluence API supports different formats of page content. The storage format is the
raw XML representation for storage. The view format is the HTML representation for raw XML representation for storage. The view format is the HTML representation for
viewing, with macros rendered as though it is viewed by users. You can pass viewing, with macros rendered as though it is viewed by users. You can pass
an enum `content_format` argument to `load()` to specify the content format, this is an enum `content_format` argument to specify the content format, this is
set to `ContentFormat.STORAGE` by default, the supported values are: set to `ContentFormat.STORAGE` by default, the supported values are:
`ContentFormat.EDITOR`, `ContentFormat.EXPORT_VIEW`, `ContentFormat.EDITOR`, `ContentFormat.EXPORT_VIEW`,
`ContentFormat.ANONYMOUS_EXPORT_VIEW`, `ContentFormat.STORAGE`, `ContentFormat.ANONYMOUS_EXPORT_VIEW`, `ContentFormat.STORAGE`,
@ -66,18 +66,22 @@ class ConfluenceLoader(BaseLoader):
loader = ConfluenceLoader( loader = ConfluenceLoader(
url="https://yoursite.atlassian.com/wiki", url="https://yoursite.atlassian.com/wiki",
username="me", username="me",
api_key="12345" api_key="12345",
space_key="SPACE",
limit=50,
) )
documents = loader.load(space_key="SPACE",limit=50) documents = loader.load()
# Server on prem # Server on prem
loader = ConfluenceLoader( loader = ConfluenceLoader(
url="https://confluence.yoursite.com/", url="https://confluence.yoursite.com/",
username="me", username="me",
api_key="your_password", api_key="your_password",
cloud=False cloud=False,
space_key="SPACE",
limit=50,
) )
documents = loader.load(space_key="SPACE",limit=50) documents = loader.load()
:param url: _description_ :param url: _description_
:type url: str :type url: str
@ -99,6 +103,43 @@ class ConfluenceLoader(BaseLoader):
:type max_retry_seconds: Optional[int], optional :type max_retry_seconds: Optional[int], optional
:param confluence_kwargs: additional kwargs to initialize confluence with :param confluence_kwargs: additional kwargs to initialize confluence with
:type confluence_kwargs: dict, optional :type confluence_kwargs: dict, optional
:param space_key: Space key retrieved from a confluence URL, defaults to None
:type space_key: Optional[str], optional
:param page_ids: List of specific page IDs to load, defaults to None
:type page_ids: Optional[List[str]], optional
:param label: Get all pages with this label, defaults to None
:type label: Optional[str], optional
:param cql: CQL Expression, defaults to None
:type cql: Optional[str], optional
:param include_restricted_content: defaults to False
:type include_restricted_content: bool, optional
:param include_archived_content: Whether to include archived content,
defaults to False
:type include_archived_content: bool, optional
:param include_attachments: defaults to False
:type include_attachments: bool, optional
:param include_comments: defaults to False
:type include_comments: bool, optional
:param content_format: Specify content format, defaults to
ContentFormat.STORAGE, the supported values are:
`ContentFormat.EDITOR`, `ContentFormat.EXPORT_VIEW`,
`ContentFormat.ANONYMOUS_EXPORT_VIEW`,
`ContentFormat.STORAGE`, and `ContentFormat.VIEW`.
:type content_format: ContentFormat
:param limit: Maximum number of pages to retrieve per request, defaults to 50
:type limit: int, optional
:param max_pages: Maximum number of pages to retrieve in total, defaults to 1000
:type max_pages: int, optional
:param ocr_languages: The languages to use for the Tesseract agent. To use a
language, you'll first need to install the appropriate
Tesseract language pack.
:type ocr_languages: str, optional
:param keep_markdown_format: Whether to keep the markdown format, defaults to
False
:type keep_markdown_format: bool
:param keep_newlines: Whether to keep the newlines format, defaults to
False
:type keep_newlines: bool
:raises ValueError: Errors while validating input :raises ValueError: Errors while validating input
:raises ImportError: Required dependencies not installed. :raises ImportError: Required dependencies not installed.
""" """
@ -116,7 +157,37 @@ class ConfluenceLoader(BaseLoader):
min_retry_seconds: Optional[int] = 2, min_retry_seconds: Optional[int] = 2,
max_retry_seconds: Optional[int] = 10, max_retry_seconds: Optional[int] = 10,
confluence_kwargs: Optional[dict] = None, confluence_kwargs: Optional[dict] = None,
*,
space_key: Optional[str] = None,
page_ids: Optional[List[str]] = None,
label: Optional[str] = None,
cql: Optional[str] = None,
include_restricted_content: bool = False,
include_archived_content: bool = False,
include_attachments: bool = False,
include_comments: bool = False,
content_format: ContentFormat = ContentFormat.STORAGE,
limit: Optional[int] = 50,
max_pages: Optional[int] = 1000,
ocr_languages: Optional[str] = None,
keep_markdown_format: bool = False,
keep_newlines: bool = False,
): ):
self.space_key = space_key
self.page_ids = page_ids
self.label = label
self.cql = cql
self.include_restricted_content = include_restricted_content
self.include_archived_content = include_archived_content
self.include_attachments = include_attachments
self.include_comments = include_comments
self.content_format = content_format
self.limit = limit
self.max_pages = max_pages
self.ocr_languages = ocr_languages
self.keep_markdown_format = keep_markdown_format
self.keep_newlines = keep_newlines
confluence_kwargs = confluence_kwargs or {} confluence_kwargs = confluence_kwargs or {}
errors = ConfluenceLoader.validate_init_args( errors = ConfluenceLoader.validate_init_args(
url=url, url=url,
@ -204,74 +275,40 @@ class ConfluenceLoader(BaseLoader):
) )
return errors or None return errors or None
def load( def _resolve_param(self, param_name: str, kwargs: Any) -> Any:
self, return kwargs[param_name] if param_name in kwargs else getattr(self, param_name)
space_key: Optional[str] = None,
page_ids: Optional[List[str]] = None, def _lazy_load(self, **kwargs: Any) -> Iterator[Document]:
label: Optional[str] = None, if kwargs:
cql: Optional[str] = None, logger.warning(
include_restricted_content: bool = False, f"Received runtime arguments {kwargs}. Passing runtime args to `load`"
include_archived_content: bool = False, f" is deprecated. Please pass arguments during initialization instead."
include_attachments: bool = False, )
include_comments: bool = False, space_key = self._resolve_param("space_key", kwargs)
content_format: ContentFormat = ContentFormat.STORAGE, page_ids = self._resolve_param("page_ids", kwargs)
limit: Optional[int] = 50, label = self._resolve_param("label", kwargs)
max_pages: Optional[int] = 1000, cql = self._resolve_param("cql", kwargs)
ocr_languages: Optional[str] = None, include_restricted_content = self._resolve_param(
keep_markdown_format: bool = False, "include_restricted_content", kwargs
keep_newlines: bool = False, )
) -> List[Document]: include_archived_content = self._resolve_param(
""" "include_archived_content", kwargs
:param space_key: Space key retrieved from a confluence URL, defaults to None )
:type space_key: Optional[str], optional include_attachments = self._resolve_param("include_attachments", kwargs)
:param page_ids: List of specific page IDs to load, defaults to None include_comments = self._resolve_param("include_comments", kwargs)
:type page_ids: Optional[List[str]], optional content_format = self._resolve_param("content_format", kwargs)
:param label: Get all pages with this label, defaults to None limit = self._resolve_param("limit", kwargs)
:type label: Optional[str], optional max_pages = self._resolve_param("max_pages", kwargs)
:param cql: CQL Expression, defaults to None ocr_languages = self._resolve_param("ocr_languages", kwargs)
:type cql: Optional[str], optional keep_markdown_format = self._resolve_param("keep_markdown_format", kwargs)
:param include_restricted_content: defaults to False keep_newlines = self._resolve_param("keep_newlines", kwargs)
:type include_restricted_content: bool, optional
:param include_archived_content: Whether to include archived content,
defaults to False
:type include_archived_content: bool, optional
:param include_attachments: defaults to False
:type include_attachments: bool, optional
:param include_comments: defaults to False
:type include_comments: bool, optional
:param content_format: Specify content format, defaults to
ContentFormat.STORAGE, the supported values are:
`ContentFormat.EDITOR`, `ContentFormat.EXPORT_VIEW`,
`ContentFormat.ANONYMOUS_EXPORT_VIEW`,
`ContentFormat.STORAGE`, and `ContentFormat.VIEW`.
:type content_format: ContentFormat
:param limit: Maximum number of pages to retrieve per request, defaults to 50
:type limit: int, optional
:param max_pages: Maximum number of pages to retrieve in total, defaults to 1000
:type max_pages: int, optional
:param ocr_languages: The languages to use for the Tesseract agent. To use a
language, you'll first need to install the appropriate
Tesseract language pack.
:type ocr_languages: str, optional
:param keep_markdown_format: Whether to keep the markdown format, defaults to
False
:type keep_markdown_format: bool
:param keep_newlines: Whether to keep the newlines format, defaults to
False
:type keep_newlines: bool
:raises ValueError: _description_
:raises ImportError: _description_
:return: _description_
:rtype: List[Document]
"""
if not space_key and not page_ids and not label and not cql: if not space_key and not page_ids and not label and not cql:
raise ValueError( raise ValueError(
"Must specify at least one among `space_key`, `page_ids`, " "Must specify at least one among `space_key`, `page_ids`, "
"`label`, `cql` parameters." "`label`, `cql` parameters."
) )
docs = []
if space_key: if space_key:
pages = self.paginate_request( pages = self.paginate_request(
self.confluence.get_all_pages_from_space, self.confluence.get_all_pages_from_space,
@ -281,7 +318,7 @@ class ConfluenceLoader(BaseLoader):
status="any" if include_archived_content else "current", status="any" if include_archived_content else "current",
expand=content_format.value, expand=content_format.value,
) )
docs += self.process_pages( yield from self.process_pages(
pages, pages,
include_restricted_content, include_restricted_content,
include_attachments, include_attachments,
@ -314,7 +351,7 @@ class ConfluenceLoader(BaseLoader):
include_archived_spaces=include_archived_content, include_archived_spaces=include_archived_content,
expand=content_format.value, expand=content_format.value,
) )
docs += self.process_pages( yield from self.process_pages(
pages, pages,
include_restricted_content, include_restricted_content,
include_attachments, include_attachments,
@ -343,7 +380,7 @@ class ConfluenceLoader(BaseLoader):
) )
if not include_restricted_content and not self.is_public_page(page): if not include_restricted_content and not self.is_public_page(page):
continue continue
doc = self.process_page( yield self.process_page(
page, page,
include_attachments, include_attachments,
include_comments, include_comments,
@ -351,9 +388,12 @@ class ConfluenceLoader(BaseLoader):
ocr_languages, ocr_languages,
keep_markdown_format, keep_markdown_format,
) )
docs.append(doc)
return docs def load(self, **kwargs: Any) -> List[Document]:
return list(self._lazy_load(**kwargs))
def lazy_load(self) -> Iterator[Document]:
yield from self._lazy_load()
def _search_content_by_cql( def _search_content_by_cql(
self, cql: str, include_archived_spaces: Optional[bool] = None, **kwargs: Any self, cql: str, include_archived_spaces: Optional[bool] = None, **kwargs: Any
@ -430,13 +470,12 @@ class ConfluenceLoader(BaseLoader):
ocr_languages: Optional[str] = None, ocr_languages: Optional[str] = None,
keep_markdown_format: Optional[bool] = False, keep_markdown_format: Optional[bool] = False,
keep_newlines: bool = False, keep_newlines: bool = False,
) -> List[Document]: ) -> Iterator[Document]:
"""Process a list of pages into a list of documents.""" """Process a list of pages into a list of documents."""
docs = []
for page in pages: for page in pages:
if not include_restricted_content and not self.is_public_page(page): if not include_restricted_content and not self.is_public_page(page):
continue continue
doc = self.process_page( yield self.process_page(
page, page,
include_attachments, include_attachments,
include_comments, include_comments,
@ -445,9 +484,6 @@ class ConfluenceLoader(BaseLoader):
keep_markdown_format=keep_markdown_format, keep_markdown_format=keep_markdown_format,
keep_newlines=keep_newlines, keep_newlines=keep_newlines,
) )
docs.append(doc)
return docs
def process_page( def process_page(
self, self,

View File

@ -1,5 +1,5 @@
import unittest import unittest
from typing import Dict from typing import Any, Dict
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
@ -108,10 +108,12 @@ class TestConfluenceLoader:
self._get_mock_page_restrictions("456"), self._get_mock_page_restrictions("456"),
] ]
confluence_loader = self._get_mock_confluence_loader(mock_confluence)
mock_page_ids = ["123", "456"] mock_page_ids = ["123", "456"]
documents = confluence_loader.load(page_ids=mock_page_ids) confluence_loader = self._get_mock_confluence_loader(
mock_confluence, page_ids=mock_page_ids
)
documents = list(confluence_loader.lazy_load())
assert mock_confluence.get_page_by_id.call_count == 2 assert mock_confluence.get_page_by_id.call_count == 2
assert mock_confluence.get_all_restrictions_for_content.call_count == 2 assert mock_confluence.get_all_restrictions_for_content.call_count == 2
@ -139,9 +141,11 @@ class TestConfluenceLoader:
self._get_mock_page_restrictions("456"), self._get_mock_page_restrictions("456"),
] ]
confluence_loader = self._get_mock_confluence_loader(mock_confluence) confluence_loader = self._get_mock_confluence_loader(
mock_confluence, space_key=self.MOCK_SPACE_KEY, max_pages=2
)
documents = confluence_loader.load(space_key=self.MOCK_SPACE_KEY, max_pages=2) documents = confluence_loader.load()
assert mock_confluence.get_all_pages_from_space.call_count == 1 assert mock_confluence.get_all_pages_from_space.call_count == 1
@ -155,6 +159,7 @@ class TestConfluenceLoader:
assert mock_confluence.cql.call_count == 0 assert mock_confluence.cql.call_count == 0
assert mock_confluence.get_page_child_by_type.call_count == 0 assert mock_confluence.get_page_child_by_type.call_count == 0
@pytest.mark.requires("markdownify")
def test_confluence_loader_when_content_format_and_keep_markdown_format_enabled( def test_confluence_loader_when_content_format_and_keep_markdown_format_enabled(
self, mock_confluence: MagicMock self, mock_confluence: MagicMock
) -> None: ) -> None:
@ -168,15 +173,16 @@ class TestConfluenceLoader:
self._get_mock_page_restrictions("456"), self._get_mock_page_restrictions("456"),
] ]
confluence_loader = self._get_mock_confluence_loader(mock_confluence) confluence_loader = self._get_mock_confluence_loader(
mock_confluence,
documents = confluence_loader.load(
space_key=self.MOCK_SPACE_KEY, space_key=self.MOCK_SPACE_KEY,
content_format=ContentFormat.VIEW, content_format=ContentFormat.VIEW,
keep_markdown_format=True, keep_markdown_format=True,
max_pages=2, max_pages=2,
) )
documents = confluence_loader.load()
assert mock_confluence.get_all_pages_from_space.call_count == 1 assert mock_confluence.get_all_pages_from_space.call_count == 1
assert len(documents) == 2 assert len(documents) == 2
@ -190,12 +196,13 @@ class TestConfluenceLoader:
assert mock_confluence.get_page_child_by_type.call_count == 0 assert mock_confluence.get_page_child_by_type.call_count == 0
def _get_mock_confluence_loader( def _get_mock_confluence_loader(
self, mock_confluence: MagicMock self, mock_confluence: MagicMock, **kwargs: Any
) -> ConfluenceLoader: ) -> ConfluenceLoader:
confluence_loader = ConfluenceLoader( confluence_loader = ConfluenceLoader(
self.CONFLUENCE_URL, self.CONFLUENCE_URL,
username=self.MOCK_USERNAME, username=self.MOCK_USERNAME,
api_key=self.MOCK_API_TOKEN, api_key=self.MOCK_API_TOKEN,
**kwargs,
) )
confluence_loader.confluence = mock_confluence confluence_loader.confluence = mock_confluence
return confluence_loader return confluence_loader