mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-14 07:07:34 +00:00
improve confluence doc loader param validation (#9568)
This commit is contained in:
parent
0fea987dd2
commit
04f2d69b83
@ -118,16 +118,15 @@ class ConfluenceLoader(BaseLoader):
|
|||||||
):
|
):
|
||||||
confluence_kwargs = confluence_kwargs or {}
|
confluence_kwargs = confluence_kwargs or {}
|
||||||
errors = ConfluenceLoader.validate_init_args(
|
errors = ConfluenceLoader.validate_init_args(
|
||||||
url, api_key, username, oauth2, token
|
url=url,
|
||||||
|
api_key=api_key,
|
||||||
|
username=username,
|
||||||
|
session=session,
|
||||||
|
oauth2=oauth2,
|
||||||
|
token=token,
|
||||||
)
|
)
|
||||||
if errors:
|
if errors:
|
||||||
raise ValueError(f"Error(s) while validating input: {errors}")
|
raise ValueError(f"Error(s) while validating input: {errors}")
|
||||||
|
|
||||||
self.base_url = url
|
|
||||||
self.number_of_retries = number_of_retries
|
|
||||||
self.min_retry_seconds = min_retry_seconds
|
|
||||||
self.max_retry_seconds = max_retry_seconds
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from atlassian import Confluence # noqa: F401
|
from atlassian import Confluence # noqa: F401
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -136,6 +135,11 @@ class ConfluenceLoader(BaseLoader):
|
|||||||
"`pip install atlassian-python-api`"
|
"`pip install atlassian-python-api`"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.base_url = url
|
||||||
|
self.number_of_retries = number_of_retries
|
||||||
|
self.min_retry_seconds = min_retry_seconds
|
||||||
|
self.max_retry_seconds = max_retry_seconds
|
||||||
|
|
||||||
if session:
|
if session:
|
||||||
self.confluence = Confluence(url=url, session=session, **confluence_kwargs)
|
self.confluence = Confluence(url=url, session=session, **confluence_kwargs)
|
||||||
elif oauth2:
|
elif oauth2:
|
||||||
@ -160,6 +164,7 @@ class ConfluenceLoader(BaseLoader):
|
|||||||
url: Optional[str] = None,
|
url: Optional[str] = None,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
username: Optional[str] = None,
|
username: Optional[str] = None,
|
||||||
|
session: Optional[requests.Session] = None,
|
||||||
oauth2: Optional[dict] = None,
|
oauth2: Optional[dict] = None,
|
||||||
token: Optional[str] = None,
|
token: Optional[str] = None,
|
||||||
) -> Union[List, None]:
|
) -> Union[List, None]:
|
||||||
@ -175,33 +180,28 @@ class ConfluenceLoader(BaseLoader):
|
|||||||
"the other must be as well."
|
"the other must be as well."
|
||||||
)
|
)
|
||||||
|
|
||||||
if (api_key or username) and oauth2:
|
non_null_creds = list(
|
||||||
|
x is not None for x in ((api_key or username), session, oauth2, token)
|
||||||
|
)
|
||||||
|
if sum(non_null_creds) > 1:
|
||||||
|
all_names = ("(api_key, username)", "session", "oath2", "token")
|
||||||
|
provided = tuple(n for x, n in zip(non_null_creds, all_names) if x)
|
||||||
errors.append(
|
errors.append(
|
||||||
"Cannot provide a value for `api_key` and/or "
|
f"Cannot provide a value for more than one of: {all_names}. Received "
|
||||||
"`username` and provide a value for `oauth2`"
|
f"values for: {provided}"
|
||||||
)
|
)
|
||||||
|
if oauth2 and set(oauth2.keys()) != {
|
||||||
if oauth2 and oauth2.keys() != [
|
|
||||||
"access_token",
|
"access_token",
|
||||||
"access_token_secret",
|
"access_token_secret",
|
||||||
"consumer_key",
|
"consumer_key",
|
||||||
"key_cert",
|
"key_cert",
|
||||||
]:
|
}:
|
||||||
errors.append(
|
errors.append(
|
||||||
"You have either omitted require keys or added extra "
|
"You have either omitted require keys or added extra "
|
||||||
"keys to the oauth2 dictionary. key values should be "
|
"keys to the oauth2 dictionary. key values should be "
|
||||||
"`['access_token', 'access_token_secret', 'consumer_key', 'key_cert']`"
|
"`['access_token', 'access_token_secret', 'consumer_key', 'key_cert']`"
|
||||||
)
|
)
|
||||||
|
return errors or None
|
||||||
if token and (api_key or username or oauth2):
|
|
||||||
errors.append(
|
|
||||||
"Cannot provide a value for `token` and a value for `api_key`, "
|
|
||||||
"`username` or `oauth2`"
|
|
||||||
)
|
|
||||||
|
|
||||||
if errors:
|
|
||||||
return errors
|
|
||||||
return None
|
|
||||||
|
|
||||||
def load(
|
def load(
|
||||||
self,
|
self,
|
||||||
|
@ -3,6 +3,7 @@ from typing import Dict
|
|||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import requests
|
||||||
|
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.document_loaders.confluence import ConfluenceLoader
|
from langchain.document_loaders.confluence import ConfluenceLoader
|
||||||
@ -23,7 +24,7 @@ class TestConfluenceLoader:
|
|||||||
|
|
||||||
def test_confluence_loader_initialization(self, mock_confluence: MagicMock) -> None:
|
def test_confluence_loader_initialization(self, mock_confluence: MagicMock) -> None:
|
||||||
ConfluenceLoader(
|
ConfluenceLoader(
|
||||||
url=self.CONFLUENCE_URL,
|
self.CONFLUENCE_URL,
|
||||||
username=self.MOCK_USERNAME,
|
username=self.MOCK_USERNAME,
|
||||||
api_key=self.MOCK_API_TOKEN,
|
api_key=self.MOCK_API_TOKEN,
|
||||||
)
|
)
|
||||||
@ -34,6 +35,36 @@ class TestConfluenceLoader:
|
|||||||
cloud=True,
|
cloud=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_confluence_loader_initialization_invalid(self) -> None:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
ConfluenceLoader(
|
||||||
|
self.CONFLUENCE_URL,
|
||||||
|
username=self.MOCK_USERNAME,
|
||||||
|
api_key=self.MOCK_API_TOKEN,
|
||||||
|
token="foo",
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
ConfluenceLoader(
|
||||||
|
self.CONFLUENCE_URL,
|
||||||
|
username=self.MOCK_USERNAME,
|
||||||
|
api_key=self.MOCK_API_TOKEN,
|
||||||
|
oauth2={
|
||||||
|
"access_token": "bar",
|
||||||
|
"access_token_secret": "bar",
|
||||||
|
"consumer_key": "bar",
|
||||||
|
"key_cert": "bar",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
ConfluenceLoader(
|
||||||
|
self.CONFLUENCE_URL,
|
||||||
|
username=self.MOCK_USERNAME,
|
||||||
|
api_key=self.MOCK_API_TOKEN,
|
||||||
|
session=requests.Session(),
|
||||||
|
)
|
||||||
|
|
||||||
def test_confluence_loader_initialization_from_env(
|
def test_confluence_loader_initialization_from_env(
|
||||||
self, mock_confluence: MagicMock
|
self, mock_confluence: MagicMock
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -51,7 +82,7 @@ class TestConfluenceLoader:
|
|||||||
|
|
||||||
def test_confluence_loader_load_data_invalid_args(self) -> None:
|
def test_confluence_loader_load_data_invalid_args(self) -> None:
|
||||||
confluence_loader = ConfluenceLoader(
|
confluence_loader = ConfluenceLoader(
|
||||||
url=self.CONFLUENCE_URL,
|
self.CONFLUENCE_URL,
|
||||||
username=self.MOCK_USERNAME,
|
username=self.MOCK_USERNAME,
|
||||||
api_key=self.MOCK_API_TOKEN,
|
api_key=self.MOCK_API_TOKEN,
|
||||||
)
|
)
|
||||||
@ -125,7 +156,7 @@ class TestConfluenceLoader:
|
|||||||
self, mock_confluence: MagicMock
|
self, mock_confluence: MagicMock
|
||||||
) -> ConfluenceLoader:
|
) -> ConfluenceLoader:
|
||||||
confluence_loader = ConfluenceLoader(
|
confluence_loader = ConfluenceLoader(
|
||||||
url=self.CONFLUENCE_URL,
|
self.CONFLUENCE_URL,
|
||||||
username=self.MOCK_USERNAME,
|
username=self.MOCK_USERNAME,
|
||||||
api_key=self.MOCK_API_TOKEN,
|
api_key=self.MOCK_API_TOKEN,
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user