langchain/libs/community/tests/unit_tests/chains/test_pebblo_retrieval.py
Erick Friis c2a3021bb0
multiple: pydantic 2 compatibility, v0.3 (#26443)
Signed-off-by: ChengZi <chen.zhang@zilliz.com>
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Co-authored-by: Dan O'Donovan <dan.odonovan@gmail.com>
Co-authored-by: Tom Daniel Grande <tomdgrande@gmail.com>
Co-authored-by: Grande <Tom.Daniel.Grande@statsbygg.no>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: ccurme <chester.curme@gmail.com>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Co-authored-by: Tomaz Bratanic <bratanic.tomaz@gmail.com>
Co-authored-by: ZhangShenao <15201440436@163.com>
Co-authored-by: Friso H. Kingma <fhkingma@gmail.com>
Co-authored-by: ChengZi <chen.zhang@zilliz.com>
Co-authored-by: Nuno Campos <nuno@langchain.dev>
Co-authored-by: Morgante Pell <morgantep@google.com>
2024-09-13 14:38:45 -07:00

135 lines
3.8 KiB
Python

"""
Unit tests for the PebbloRetrievalQA chain
"""
from typing import List
from unittest.mock import Mock
import pytest
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.vectorstores import (
InMemoryVectorStore,
VectorStore,
VectorStoreRetriever,
)
from langchain_community.chains import PebbloRetrievalQA
from langchain_community.chains.pebblo_retrieval.models import (
AuthContext,
ChainInput,
SemanticContext,
)
from langchain_community.vectorstores.pinecone import Pinecone
from tests.unit_tests.llms.fake_llm import FakeLLM
class FakeRetriever(VectorStoreRetriever):
    """
    Retriever stub for tests: every lookup yields a single document whose
    content is the query string itself.
    """

    # A Mock stands in for a real vector store; tests in this module
    # overwrite its ``__class__`` to imitate specific store types.
    vectorstore: VectorStore = Mock()

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return the query wrapped in a one-element document list."""
        echoed = Document(page_content=query)
        return [echoed]

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Async variant: return the query wrapped in a one-element list."""
        echoed = Document(page_content=query)
        return [echoed]
@pytest.fixture
def retriever() -> FakeRetriever:
    """
    Provide a FakeRetriever whose vectorstore reports itself as Pinecone
    """
    fake = FakeRetriever()
    fake.search_kwargs = {}
    # Make the mocked vectorstore report Pinecone as its class
    fake.vectorstore.__class__ = Pinecone
    return fake
@pytest.fixture
def pebblo_retrieval_qa(retriever: FakeRetriever) -> PebbloRetrievalQA:
    """
    Provide a PebbloRetrievalQA chain built from a FakeLLM and the
    ``retriever`` fixture
    """
    return PebbloRetrievalQA.from_chain_type(
        llm=FakeLLM(),
        chain_type="stuff",
        retriever=retriever,
        owner="owner",
        description="description",
        app_name="app_name",
    )
def test_invoke(pebblo_retrieval_qa: PebbloRetrievalQA) -> None:
    """
    Invoking the chain with a fully populated ChainInput must yield a
    non-None response
    """
    # Deny-lists for topics and entities
    semantic_context = SemanticContext(
        pebblo_semantic_topics={"deny": ["harmful-advice"]},
        pebblo_semantic_entities={"deny": ["credit-card"]},
    )
    # Fake identity of the caller
    auth_context = AuthContext(
        user_id="fake_user@email.com",
        user_auth=["fake-group", "fake-group2"],
    )

    chain_input_obj = ChainInput(
        query="What is the meaning of life?",
        auth_context=auth_context,
        semantic_context=semantic_context,
    )
    payload = chain_input_obj.dict()

    response = pebblo_retrieval_qa.invoke(payload)
    assert response is not None
def test_validate_vectorstore(retriever: FakeRetriever) -> None:
    """
    Building the chain succeeds for a Pinecone-backed retriever and raises
    ValueError for an InMemoryVectorStore-backed one
    """

    def build_chain(candidate: FakeRetriever) -> PebbloRetrievalQA:
        # Same construction arguments for both cases; only the retriever varies
        return PebbloRetrievalQA.from_chain_type(
            llm=FakeLLM(),
            chain_type="stuff",
            retriever=candidate,
            owner="owner",
            description="description",
            app_name="app_name",
        )

    # Supported vectorstore (Pinecone): construction must not raise
    build_chain(retriever)

    # Retriever whose mocked vectorstore reports InMemoryVectorStore as its class
    unsupported_retriever = FakeRetriever()
    unsupported_retriever.search_kwargs = {}
    unsupported_retriever.vectorstore.__class__ = InMemoryVectorStore

    # Unsupported vectorstore: construction must fail with ValueError
    with pytest.raises(ValueError) as exc_info:
        build_chain(unsupported_retriever)

    expected_fragment = (
        "Vectorstore must be an instance of one of the supported vectorstores"
    )
    assert expected_fragment in str(exc_info.value)