tests: init retriever standard tests (#28459)

This commit is contained in:
Erick Friis 2024-12-02 15:36:09 -08:00 committed by GitHub
parent 42d40d694b
commit 000be1f32c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 112 additions and 0 deletions

View File

@ -20,6 +20,7 @@ from .base_store import BaseStoreAsyncTests, BaseStoreSyncTests
from .cache import AsyncCacheTestSuite, SyncCacheTestSuite from .cache import AsyncCacheTestSuite, SyncCacheTestSuite
from .chat_models import ChatModelIntegrationTests from .chat_models import ChatModelIntegrationTests
from .embeddings import EmbeddingsIntegrationTests from .embeddings import EmbeddingsIntegrationTests
from .retrievers import RetrieversIntegrationTests
from .tools import ToolsIntegrationTests from .tools import ToolsIntegrationTests
from .vectorstores import AsyncReadWriteTestSuite, ReadWriteTestSuite from .vectorstores import AsyncReadWriteTestSuite, ReadWriteTestSuite
@ -33,4 +34,5 @@ __all__ = [
"SyncCacheTestSuite", "SyncCacheTestSuite",
"AsyncReadWriteTestSuite", "AsyncReadWriteTestSuite",
"ReadWriteTestSuite", "ReadWriteTestSuite",
"RetrieversIntegrationTests",
] ]

View File

@ -0,0 +1,78 @@
from abc import abstractmethod
from typing import Type
import pytest
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_tests.base import BaseStandardTests
class RetrieversIntegrationTests(BaseStandardTests):
@property
@abstractmethod
def retriever_constructor(self) -> Type[BaseRetriever]: ...
@property
def retriever_constructor_params(self) -> dict:
return {}
@property
@abstractmethod
def retriever_query_example(self) -> str:
"""
Returns a dictionary representing the "args" of an example retriever call.
"""
...
@pytest.fixture
def retriever(self) -> BaseRetriever:
return self.retriever_constructor(**self.retriever_constructor_params)
def test_k_constructor_param(self) -> None:
"""
Test that the retriever constructor accepts a k parameter.
"""
params = {
k: v for k, v in self.retriever_constructor_params.items() if k != "k"
}
params_3 = {**params, "k": 3}
retriever_3 = self.retriever_constructor(**params_3)
result_3 = retriever_3.invoke(self.retriever_query_example)
assert len(result_3) == 3
assert all(isinstance(doc, Document) for doc in result_3)
params_1 = {**params, "k": 1}
retriever_1 = self.retriever_constructor(**params_1)
result_1 = retriever_1.invoke(self.retriever_query_example)
assert len(result_1) == 1
assert all(isinstance(doc, Document) for doc in result_1)
def test_invoke_with_k_kwarg(self, retriever: BaseRetriever) -> None:
result_1 = retriever.invoke(self.retriever_query_example, k=1)
assert len(result_1) == 1
assert all(isinstance(doc, Document) for doc in result_1)
result_3 = retriever.invoke(self.retriever_query_example, k=3)
assert len(result_3) == 3
assert all(isinstance(doc, Document) for doc in result_3)
def test_invoke_returns_documents(self, retriever: BaseRetriever) -> None:
"""
If invoked with the example params, the retriever should return a list of
Documents.
"""
result = retriever.invoke(self.retriever_query_example)
assert isinstance(result, list)
assert all(isinstance(doc, Document) for doc in result)
async def test_ainvoke_returns_documents(self, retriever: BaseRetriever) -> None:
"""
If ainvoked with the example params, the retriever should return a list of
Documents.
"""
result = await retriever.ainvoke(self.retriever_query_example)
assert isinstance(result, list)
assert all(isinstance(doc, Document) for doc in result)

View File

@ -0,0 +1,32 @@
from typing import Any, Type
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_tests.integration_tests import RetrieversIntegrationTests
class ParrotRetriever(BaseRetriever):
parrot_name: str
k: int = 3
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
) -> list[Document]:
k = kwargs.get("k", self.k)
return [Document(page_content=f"{self.parrot_name} says: {query}")] * k
class TestParrotRetrieverIntegration(RetrieversIntegrationTests):
@property
def retriever_constructor(self) -> Type[ParrotRetriever]:
return ParrotRetriever
@property
def retriever_constructor_params(self) -> dict:
return {"parrot_name": "Polly"}
@property
def retriever_query_example(self) -> str:
return "parrot"