tests: init retriever standard tests (#28459)

This commit is contained in:
Erick Friis 2024-12-02 15:36:09 -08:00 committed by GitHub
parent 42d40d694b
commit 000be1f32c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 112 additions and 0 deletions

View File

@ -20,6 +20,7 @@ from .base_store import BaseStoreAsyncTests, BaseStoreSyncTests
from .cache import AsyncCacheTestSuite, SyncCacheTestSuite
from .chat_models import ChatModelIntegrationTests
from .embeddings import EmbeddingsIntegrationTests
from .retrievers import RetrieversIntegrationTests
from .tools import ToolsIntegrationTests
from .vectorstores import AsyncReadWriteTestSuite, ReadWriteTestSuite
@ -33,4 +34,5 @@ __all__ = [
"SyncCacheTestSuite",
"AsyncReadWriteTestSuite",
"ReadWriteTestSuite",
"RetrieversIntegrationTests",
]

View File

@ -0,0 +1,78 @@
from abc import abstractmethod
from typing import Type
import pytest
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_tests.base import BaseStandardTests
class RetrieversIntegrationTests(BaseStandardTests):
@property
@abstractmethod
def retriever_constructor(self) -> Type[BaseRetriever]: ...
@property
def retriever_constructor_params(self) -> dict:
return {}
@property
@abstractmethod
def retriever_query_example(self) -> str:
"""
Returns a dictionary representing the "args" of an example retriever call.
"""
...
@pytest.fixture
def retriever(self) -> BaseRetriever:
return self.retriever_constructor(**self.retriever_constructor_params)
def test_k_constructor_param(self) -> None:
"""
Test that the retriever constructor accepts a k parameter.
"""
params = {
k: v for k, v in self.retriever_constructor_params.items() if k != "k"
}
params_3 = {**params, "k": 3}
retriever_3 = self.retriever_constructor(**params_3)
result_3 = retriever_3.invoke(self.retriever_query_example)
assert len(result_3) == 3
assert all(isinstance(doc, Document) for doc in result_3)
params_1 = {**params, "k": 1}
retriever_1 = self.retriever_constructor(**params_1)
result_1 = retriever_1.invoke(self.retriever_query_example)
assert len(result_1) == 1
assert all(isinstance(doc, Document) for doc in result_1)
def test_invoke_with_k_kwarg(self, retriever: BaseRetriever) -> None:
result_1 = retriever.invoke(self.retriever_query_example, k=1)
assert len(result_1) == 1
assert all(isinstance(doc, Document) for doc in result_1)
result_3 = retriever.invoke(self.retriever_query_example, k=3)
assert len(result_3) == 3
assert all(isinstance(doc, Document) for doc in result_3)
def test_invoke_returns_documents(self, retriever: BaseRetriever) -> None:
"""
If invoked with the example params, the retriever should return a list of
Documents.
"""
result = retriever.invoke(self.retriever_query_example)
assert isinstance(result, list)
assert all(isinstance(doc, Document) for doc in result)
async def test_ainvoke_returns_documents(self, retriever: BaseRetriever) -> None:
"""
If ainvoked with the example params, the retriever should return a list of
Documents.
"""
result = await retriever.ainvoke(self.retriever_query_example)
assert isinstance(result, list)
assert all(isinstance(doc, Document) for doc in result)

View File

@ -0,0 +1,32 @@
from typing import Any, Type
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_tests.integration_tests import RetrieversIntegrationTests
class ParrotRetriever(BaseRetriever):
parrot_name: str
k: int = 3
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
) -> list[Document]:
k = kwargs.get("k", self.k)
return [Document(page_content=f"{self.parrot_name} says: {query}")] * k
class TestParrotRetrieverIntegration(RetrieversIntegrationTests):
@property
def retriever_constructor(self) -> Type[ParrotRetriever]:
return ParrotRetriever
@property
def retriever_constructor_params(self) -> dict:
return {"parrot_name": "Polly"}
@property
def retriever_query_example(self) -> str:
return "parrot"