mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-29 01:48:57 +00:00
tests: init retriever standard tests (#28459)
This commit is contained in:
parent
42d40d694b
commit
000be1f32c
@ -20,6 +20,7 @@ from .base_store import BaseStoreAsyncTests, BaseStoreSyncTests
|
|||||||
from .cache import AsyncCacheTestSuite, SyncCacheTestSuite
|
from .cache import AsyncCacheTestSuite, SyncCacheTestSuite
|
||||||
from .chat_models import ChatModelIntegrationTests
|
from .chat_models import ChatModelIntegrationTests
|
||||||
from .embeddings import EmbeddingsIntegrationTests
|
from .embeddings import EmbeddingsIntegrationTests
|
||||||
|
from .retrievers import RetrieversIntegrationTests
|
||||||
from .tools import ToolsIntegrationTests
|
from .tools import ToolsIntegrationTests
|
||||||
from .vectorstores import AsyncReadWriteTestSuite, ReadWriteTestSuite
|
from .vectorstores import AsyncReadWriteTestSuite, ReadWriteTestSuite
|
||||||
|
|
||||||
@ -33,4 +34,5 @@ __all__ = [
|
|||||||
"SyncCacheTestSuite",
|
"SyncCacheTestSuite",
|
||||||
"AsyncReadWriteTestSuite",
|
"AsyncReadWriteTestSuite",
|
||||||
"ReadWriteTestSuite",
|
"ReadWriteTestSuite",
|
||||||
|
"RetrieversIntegrationTests",
|
||||||
]
|
]
|
||||||
|
@ -0,0 +1,78 @@
|
|||||||
|
from abc import abstractmethod
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain_core.documents import Document
|
||||||
|
from langchain_core.retrievers import BaseRetriever
|
||||||
|
|
||||||
|
from langchain_tests.base import BaseStandardTests
|
||||||
|
|
||||||
|
|
||||||
|
class RetrieversIntegrationTests(BaseStandardTests):
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def retriever_constructor(self) -> Type[BaseRetriever]: ...
|
||||||
|
|
||||||
|
@property
|
||||||
|
def retriever_constructor_params(self) -> dict:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def retriever_query_example(self) -> str:
|
||||||
|
"""
|
||||||
|
Returns a dictionary representing the "args" of an example retriever call.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def retriever(self) -> BaseRetriever:
|
||||||
|
return self.retriever_constructor(**self.retriever_constructor_params)
|
||||||
|
|
||||||
|
def test_k_constructor_param(self) -> None:
|
||||||
|
"""
|
||||||
|
Test that the retriever constructor accepts a k parameter.
|
||||||
|
"""
|
||||||
|
params = {
|
||||||
|
k: v for k, v in self.retriever_constructor_params.items() if k != "k"
|
||||||
|
}
|
||||||
|
params_3 = {**params, "k": 3}
|
||||||
|
retriever_3 = self.retriever_constructor(**params_3)
|
||||||
|
result_3 = retriever_3.invoke(self.retriever_query_example)
|
||||||
|
assert len(result_3) == 3
|
||||||
|
assert all(isinstance(doc, Document) for doc in result_3)
|
||||||
|
|
||||||
|
params_1 = {**params, "k": 1}
|
||||||
|
retriever_1 = self.retriever_constructor(**params_1)
|
||||||
|
result_1 = retriever_1.invoke(self.retriever_query_example)
|
||||||
|
assert len(result_1) == 1
|
||||||
|
assert all(isinstance(doc, Document) for doc in result_1)
|
||||||
|
|
||||||
|
def test_invoke_with_k_kwarg(self, retriever: BaseRetriever) -> None:
|
||||||
|
result_1 = retriever.invoke(self.retriever_query_example, k=1)
|
||||||
|
assert len(result_1) == 1
|
||||||
|
assert all(isinstance(doc, Document) for doc in result_1)
|
||||||
|
|
||||||
|
result_3 = retriever.invoke(self.retriever_query_example, k=3)
|
||||||
|
assert len(result_3) == 3
|
||||||
|
assert all(isinstance(doc, Document) for doc in result_3)
|
||||||
|
|
||||||
|
def test_invoke_returns_documents(self, retriever: BaseRetriever) -> None:
|
||||||
|
"""
|
||||||
|
If invoked with the example params, the retriever should return a list of
|
||||||
|
Documents.
|
||||||
|
"""
|
||||||
|
result = retriever.invoke(self.retriever_query_example)
|
||||||
|
|
||||||
|
assert isinstance(result, list)
|
||||||
|
assert all(isinstance(doc, Document) for doc in result)
|
||||||
|
|
||||||
|
async def test_ainvoke_returns_documents(self, retriever: BaseRetriever) -> None:
|
||||||
|
"""
|
||||||
|
If ainvoked with the example params, the retriever should return a list of
|
||||||
|
Documents.
|
||||||
|
"""
|
||||||
|
result = await retriever.ainvoke(self.retriever_query_example)
|
||||||
|
|
||||||
|
assert isinstance(result, list)
|
||||||
|
assert all(isinstance(doc, Document) for doc in result)
|
32
libs/standard-tests/tests/unit_tests/test_basic_retriever.py
Normal file
32
libs/standard-tests/tests/unit_tests/test_basic_retriever.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
from typing import Any, Type
|
||||||
|
|
||||||
|
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||||
|
from langchain_core.documents import Document
|
||||||
|
from langchain_core.retrievers import BaseRetriever
|
||||||
|
|
||||||
|
from langchain_tests.integration_tests import RetrieversIntegrationTests
|
||||||
|
|
||||||
|
|
||||||
|
class ParrotRetriever(BaseRetriever):
|
||||||
|
parrot_name: str
|
||||||
|
k: int = 3
|
||||||
|
|
||||||
|
def _get_relevant_documents(
|
||||||
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
|
||||||
|
) -> list[Document]:
|
||||||
|
k = kwargs.get("k", self.k)
|
||||||
|
return [Document(page_content=f"{self.parrot_name} says: {query}")] * k
|
||||||
|
|
||||||
|
|
||||||
|
class TestParrotRetrieverIntegration(RetrieversIntegrationTests):
|
||||||
|
@property
|
||||||
|
def retriever_constructor(self) -> Type[ParrotRetriever]:
|
||||||
|
return ParrotRetriever
|
||||||
|
|
||||||
|
@property
|
||||||
|
def retriever_constructor_params(self) -> dict:
|
||||||
|
return {"parrot_name": "Polly"}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def retriever_query_example(self) -> str:
|
||||||
|
return "parrot"
|
Loading…
Reference in New Issue
Block a user