mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-01 13:26:15 +00:00
Signed-off-by: ChengZi <chen.zhang@zilliz.com> Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com> Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> Co-authored-by: Dan O'Donovan <dan.odonovan@gmail.com> Co-authored-by: Tom Daniel Grande <tomdgrande@gmail.com> Co-authored-by: Grande <Tom.Daniel.Grande@statsbygg.no> Co-authored-by: Bagatur <baskaryan@gmail.com> Co-authored-by: ccurme <chester.curme@gmail.com> Co-authored-by: Harrison Chase <hw.chase.17@gmail.com> Co-authored-by: Tomaz Bratanic <bratanic.tomaz@gmail.com> Co-authored-by: ZhangShenao <15201440436@163.com> Co-authored-by: Friso H. Kingma <fhkingma@gmail.com> Co-authored-by: ChengZi <chen.zhang@zilliz.com> Co-authored-by: Nuno Campos <nuno@langchain.dev> Co-authored-by: Morgante Pell <morgantep@google.com>
76 lines
2.3 KiB
Python
76 lines
2.3 KiB
Python
from typing import List
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
from langchain_core.documents import Document
|
|
|
|
from langchain_community.retrievers import AmazonKnowledgeBasesRetriever
|
|
|
|
|
|
@pytest.fixture
def mock_client() -> MagicMock:
    """Provide a fresh mock standing in for the retriever's AWS client.

    A bare ``MagicMock`` accepts any attribute access or call, so each test
    configures only the responses it needs (for example
    ``mock_client.retrieve.return_value``).
    """
    client = MagicMock()
    return client
|
|
|
|
|
|
@pytest.fixture
def mock_retriever_config() -> dict:
    """Return a minimal retrieval configuration that asks for four results."""
    vector_search = {"numberOfResults": 4}
    return {"vectorSearchConfiguration": vector_search}
|
|
|
|
|
|
@pytest.fixture
def amazon_retriever(
    mock_client: MagicMock, mock_retriever_config: dict
) -> AmazonKnowledgeBasesRetriever:
    """Build the retriever under test, wired to the mocked client and config.

    The retriever is handed ``mock_client`` as its client, so tests control
    what the knowledge base "returns" by configuring that mock.
    """
    retriever = AmazonKnowledgeBasesRetriever(
        client=mock_client,
        knowledge_base_id="test_kb_id",
        # A plain dict is passed where the field is typed as a config model;
        # the model is expected to coerce it, hence the type-ignore.
        retrieval_config=mock_retriever_config,  # type: ignore[arg-type]
    )
    return retriever
|
|
|
|
|
|
def test_create_client() -> None:
    """Constructing the retriever with no arguments must fail.

    Accepted failures:
    - ``ImportError`` when boto3 is not installed;
    - ``ValueError`` when AWS credentials are not supplied.

    NOTE(review): no ``knowledge_base_id`` is passed here either; if that field
    is required, model validation may raise a ``ValueError`` subclass on its
    own, so this test may not isolate client creation — confirm.
    """
    expected_errors = (ImportError, ValueError)
    with pytest.raises(expected_errors):
        AmazonKnowledgeBasesRetriever()  # type: ignore
|
|
|
|
|
|
def test_standard_params(amazon_retriever: AmazonKnowledgeBasesRetriever) -> None:
    """``_get_ls_params`` reports exactly the retriever's name and nothing else."""
    expected = {"ls_retriever_name": "amazonknowledgebases"}
    assert amazon_retriever._get_ls_params() == expected
|
|
|
|
|
|
def test_get_relevant_documents(
    amazon_retriever: AmazonKnowledgeBasesRetriever, mock_client: MagicMock
) -> None:
    """Each ``retrievalResults`` entry becomes one ``Document``.

    Three result shapes are covered:
    - content plus ``metadata`` but no ``score``/``location``;
    - a fully populated entry;
    - content only.

    Expected mapping: ``content.text`` becomes ``page_content``; a missing
    score becomes ``0``; ``metadata`` is exposed as ``source_metadata``; and
    ``location`` is copied through only when present.
    """
    query: str = "test query"
    mock_client.retrieve.return_value = {
        "retrievalResults": [
            {"content": {"text": "result1"}, "metadata": {"key": "value1"}},
            {
                "content": {"text": "result2"},
                "metadata": {"key": "value2"},
                "score": 1,
                "location": "testLocation",
            },
            {"content": {"text": "result3"}},
        ]
    }

    documents: List[Document] = amazon_retriever._get_relevant_documents(
        query,
        run_manager=None,  # type: ignore
    )

    # The canned response above only proves something if the retriever
    # actually queried the client. Check it made exactly one call and
    # forwarded the query text and the configured knowledge base id.
    mock_client.retrieve.assert_called_once()
    call_kwargs = mock_client.retrieve.call_args.kwargs
    assert call_kwargs["knowledgeBaseId"] == "test_kb_id"
    assert call_kwargs["retrievalQuery"]["text"] == query

    assert len(documents) == 3
    assert all(isinstance(doc, Document) for doc in documents)

    assert documents[0].page_content == "result1"
    assert documents[0].metadata == {"score": 0, "source_metadata": {"key": "value1"}}

    assert documents[1].page_content == "result2"
    assert documents[1].metadata == {
        "score": 1,
        "source_metadata": {"key": "value2"},
        "location": "testLocation",
    }

    assert documents[2].page_content == "result3"
    assert documents[2].metadata == {"score": 0}
|