From 04f2d69b83d9f38fbfa865630b20f3d8f0a6a8b5 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Mon, 21 Aug 2023 15:02:36 -0700 Subject: [PATCH] improve confluence doc loader param validation (#9568) --- .../langchain/document_loaders/confluence.py | 46 +++++++++---------- .../document_loaders/test_confluence.py | 37 +++++++++++++-- 2 files changed, 57 insertions(+), 26 deletions(-) diff --git a/libs/langchain/langchain/document_loaders/confluence.py b/libs/langchain/langchain/document_loaders/confluence.py index 3d986a7153a..a1151c08e2c 100644 --- a/libs/langchain/langchain/document_loaders/confluence.py +++ b/libs/langchain/langchain/document_loaders/confluence.py @@ -118,16 +118,15 @@ class ConfluenceLoader(BaseLoader): ): confluence_kwargs = confluence_kwargs or {} errors = ConfluenceLoader.validate_init_args( - url, api_key, username, oauth2, token + url=url, + api_key=api_key, + username=username, + session=session, + oauth2=oauth2, + token=token, ) if errors: raise ValueError(f"Error(s) while validating input: {errors}") - - self.base_url = url - self.number_of_retries = number_of_retries - self.min_retry_seconds = min_retry_seconds - self.max_retry_seconds = max_retry_seconds - try: from atlassian import Confluence # noqa: F401 except ImportError: @@ -136,6 +135,11 @@ class ConfluenceLoader(BaseLoader): "`pip install atlassian-python-api`" ) + self.base_url = url + self.number_of_retries = number_of_retries + self.min_retry_seconds = min_retry_seconds + self.max_retry_seconds = max_retry_seconds + if session: self.confluence = Confluence(url=url, session=session, **confluence_kwargs) elif oauth2: @@ -160,6 +164,7 @@ class ConfluenceLoader(BaseLoader): url: Optional[str] = None, api_key: Optional[str] = None, username: Optional[str] = None, + session: Optional[requests.Session] = None, oauth2: Optional[dict] = None, token: Optional[str] = None, ) -> Union[List, None]: @@ -175,33 +180,28 @@ class 
ConfluenceLoader(BaseLoader): "the other must be as well." ) - if (api_key or username) and oauth2: + non_null_creds = list( + x is not None for x in ((api_key or username), session, oauth2, token) + ) + if sum(non_null_creds) > 1: + all_names = ("(api_key, username)", "session", "oauth2", "token") + provided = tuple(n for x, n in zip(non_null_creds, all_names) if x) errors.append( - "Cannot provide a value for `api_key` and/or " - "`username` and provide a value for `oauth2`" + f"Cannot provide a value for more than one of: {all_names}. Received " + f"values for: {provided}" ) - - if oauth2 and oauth2.keys() != [ + if oauth2 and set(oauth2.keys()) != { "access_token", "access_token_secret", "consumer_key", "key_cert", - ]: + }: errors.append( "You have either omitted require keys or added extra " "keys to the oauth2 dictionary. key values should be " "`['access_token', 'access_token_secret', 'consumer_key', 'key_cert']`" ) - - if token and (api_key or username or oauth2): - errors.append( - "Cannot provide a value for `token` and a value for `api_key`, " - "`username` or `oauth2`" - ) - - if errors: - return errors - return None + return errors or None def load( self, diff --git a/libs/langchain/tests/unit_tests/document_loaders/test_confluence.py b/libs/langchain/tests/unit_tests/document_loaders/test_confluence.py index a3d371bf783..42de78598a6 100644 --- a/libs/langchain/tests/unit_tests/document_loaders/test_confluence.py +++ b/libs/langchain/tests/unit_tests/document_loaders/test_confluence.py @@ -3,6 +3,7 @@ from typing import Dict from unittest.mock import MagicMock, patch import pytest +import requests from langchain.docstore.document import Document from langchain.document_loaders.confluence import ConfluenceLoader @@ -23,7 +24,7 @@ class TestConfluenceLoader: def test_confluence_loader_initialization(self, mock_confluence: MagicMock) -> None: ConfluenceLoader( - url=self.CONFLUENCE_URL, + self.CONFLUENCE_URL, username=self.MOCK_USERNAME, 
api_key=self.MOCK_API_TOKEN, ) @@ -34,6 +35,36 @@ class TestConfluenceLoader: cloud=True, ) + def test_confluence_loader_initialization_invalid(self) -> None: + with pytest.raises(ValueError): + ConfluenceLoader( + self.CONFLUENCE_URL, + username=self.MOCK_USERNAME, + api_key=self.MOCK_API_TOKEN, + token="foo", + ) + + with pytest.raises(ValueError): + ConfluenceLoader( + self.CONFLUENCE_URL, + username=self.MOCK_USERNAME, + api_key=self.MOCK_API_TOKEN, + oauth2={ + "access_token": "bar", + "access_token_secret": "bar", + "consumer_key": "bar", + "key_cert": "bar", + }, + ) + + with pytest.raises(ValueError): + ConfluenceLoader( + self.CONFLUENCE_URL, + username=self.MOCK_USERNAME, + api_key=self.MOCK_API_TOKEN, + session=requests.Session(), + ) + def test_confluence_loader_initialization_from_env( self, mock_confluence: MagicMock ) -> None: @@ -51,7 +82,7 @@ class TestConfluenceLoader: def test_confluence_loader_load_data_invalid_args(self) -> None: confluence_loader = ConfluenceLoader( - url=self.CONFLUENCE_URL, + self.CONFLUENCE_URL, username=self.MOCK_USERNAME, api_key=self.MOCK_API_TOKEN, ) @@ -125,7 +156,7 @@ class TestConfluenceLoader: self, mock_confluence: MagicMock ) -> ConfluenceLoader: confluence_loader = ConfluenceLoader( - url=self.CONFLUENCE_URL, + self.CONFLUENCE_URL, username=self.MOCK_USERNAME, api_key=self.MOCK_API_TOKEN, )