mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-13 14:50:00 +00:00
improve confluence doc loader param validation (#9568)
This commit is contained in:
parent
0fea987dd2
commit
04f2d69b83
@ -118,16 +118,15 @@ class ConfluenceLoader(BaseLoader):
|
||||
):
|
||||
confluence_kwargs = confluence_kwargs or {}
|
||||
errors = ConfluenceLoader.validate_init_args(
|
||||
url, api_key, username, oauth2, token
|
||||
url=url,
|
||||
api_key=api_key,
|
||||
username=username,
|
||||
session=session,
|
||||
oauth2=oauth2,
|
||||
token=token,
|
||||
)
|
||||
if errors:
|
||||
raise ValueError(f"Error(s) while validating input: {errors}")
|
||||
|
||||
self.base_url = url
|
||||
self.number_of_retries = number_of_retries
|
||||
self.min_retry_seconds = min_retry_seconds
|
||||
self.max_retry_seconds = max_retry_seconds
|
||||
|
||||
try:
|
||||
from atlassian import Confluence # noqa: F401
|
||||
except ImportError:
|
||||
@ -136,6 +135,11 @@ class ConfluenceLoader(BaseLoader):
|
||||
"`pip install atlassian-python-api`"
|
||||
)
|
||||
|
||||
self.base_url = url
|
||||
self.number_of_retries = number_of_retries
|
||||
self.min_retry_seconds = min_retry_seconds
|
||||
self.max_retry_seconds = max_retry_seconds
|
||||
|
||||
if session:
|
||||
self.confluence = Confluence(url=url, session=session, **confluence_kwargs)
|
||||
elif oauth2:
|
||||
@ -160,6 +164,7 @@ class ConfluenceLoader(BaseLoader):
|
||||
url: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
session: Optional[requests.Session] = None,
|
||||
oauth2: Optional[dict] = None,
|
||||
token: Optional[str] = None,
|
||||
) -> Union[List, None]:
|
||||
@ -175,33 +180,28 @@ class ConfluenceLoader(BaseLoader):
|
||||
"the other must be as well."
|
||||
)
|
||||
|
||||
if (api_key or username) and oauth2:
|
||||
errors.append(
|
||||
"Cannot provide a value for `api_key` and/or "
|
||||
"`username` and provide a value for `oauth2`"
|
||||
non_null_creds = list(
|
||||
x is not None for x in ((api_key or username), session, oauth2, token)
|
||||
)
|
||||
|
||||
if oauth2 and oauth2.keys() != [
|
||||
if sum(non_null_creds) > 1:
|
||||
all_names = ("(api_key, username)", "session", "oath2", "token")
|
||||
provided = tuple(n for x, n in zip(non_null_creds, all_names) if x)
|
||||
errors.append(
|
||||
f"Cannot provide a value for more than one of: {all_names}. Received "
|
||||
f"values for: {provided}"
|
||||
)
|
||||
if oauth2 and set(oauth2.keys()) != {
|
||||
"access_token",
|
||||
"access_token_secret",
|
||||
"consumer_key",
|
||||
"key_cert",
|
||||
]:
|
||||
}:
|
||||
errors.append(
|
||||
"You have either omitted require keys or added extra "
|
||||
"keys to the oauth2 dictionary. key values should be "
|
||||
"`['access_token', 'access_token_secret', 'consumer_key', 'key_cert']`"
|
||||
)
|
||||
|
||||
if token and (api_key or username or oauth2):
|
||||
errors.append(
|
||||
"Cannot provide a value for `token` and a value for `api_key`, "
|
||||
"`username` or `oauth2`"
|
||||
)
|
||||
|
||||
if errors:
|
||||
return errors
|
||||
return None
|
||||
return errors or None
|
||||
|
||||
def load(
|
||||
self,
|
||||
|
@ -3,6 +3,7 @@ from typing import Dict
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.document_loaders.confluence import ConfluenceLoader
|
||||
@ -23,7 +24,7 @@ class TestConfluenceLoader:
|
||||
|
||||
def test_confluence_loader_initialization(self, mock_confluence: MagicMock) -> None:
|
||||
ConfluenceLoader(
|
||||
url=self.CONFLUENCE_URL,
|
||||
self.CONFLUENCE_URL,
|
||||
username=self.MOCK_USERNAME,
|
||||
api_key=self.MOCK_API_TOKEN,
|
||||
)
|
||||
@ -34,6 +35,36 @@ class TestConfluenceLoader:
|
||||
cloud=True,
|
||||
)
|
||||
|
||||
def test_confluence_loader_initialization_invalid(self) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
ConfluenceLoader(
|
||||
self.CONFLUENCE_URL,
|
||||
username=self.MOCK_USERNAME,
|
||||
api_key=self.MOCK_API_TOKEN,
|
||||
token="foo",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
ConfluenceLoader(
|
||||
self.CONFLUENCE_URL,
|
||||
username=self.MOCK_USERNAME,
|
||||
api_key=self.MOCK_API_TOKEN,
|
||||
oauth2={
|
||||
"access_token": "bar",
|
||||
"access_token_secret": "bar",
|
||||
"consumer_key": "bar",
|
||||
"key_cert": "bar",
|
||||
},
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
ConfluenceLoader(
|
||||
self.CONFLUENCE_URL,
|
||||
username=self.MOCK_USERNAME,
|
||||
api_key=self.MOCK_API_TOKEN,
|
||||
session=requests.Session(),
|
||||
)
|
||||
|
||||
def test_confluence_loader_initialization_from_env(
|
||||
self, mock_confluence: MagicMock
|
||||
) -> None:
|
||||
@ -51,7 +82,7 @@ class TestConfluenceLoader:
|
||||
|
||||
def test_confluence_loader_load_data_invalid_args(self) -> None:
|
||||
confluence_loader = ConfluenceLoader(
|
||||
url=self.CONFLUENCE_URL,
|
||||
self.CONFLUENCE_URL,
|
||||
username=self.MOCK_USERNAME,
|
||||
api_key=self.MOCK_API_TOKEN,
|
||||
)
|
||||
@ -125,7 +156,7 @@ class TestConfluenceLoader:
|
||||
self, mock_confluence: MagicMock
|
||||
) -> ConfluenceLoader:
|
||||
confluence_loader = ConfluenceLoader(
|
||||
url=self.CONFLUENCE_URL,
|
||||
self.CONFLUENCE_URL,
|
||||
username=self.MOCK_USERNAME,
|
||||
api_key=self.MOCK_API_TOKEN,
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user