mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-15 07:36:08 +00:00
feat(standard-tests): add a property to set the name of the parameter for the number of results to return (#32443)
Not all retrievers use `k` as param name to set the number of results to
return. Even in LangChain itself. Eg:
bc4251b9e0/libs/core/langchain_core/indexing/in_memory.py (L31)
So it's helpful to be able to change it for a given retriever.
The change also adds hints to disable the tests if the retriever doesn't
support setting the param in the constructor or in the invoke method
(for instance, the `InMemoryDocumentIndex` in the link supports in the
constructor but not in the invoke method).
This change is backward compatible.
---------
Co-authored-by: Mason Daugherty <mason@langchain.dev>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
parent
e120604774
commit
a647073b26
@ -24,8 +24,14 @@ class RetrieversIntegrationTests(BaseStandardTests):
|
||||
@property
|
||||
@abstractmethod
|
||||
def retriever_query_example(self) -> str:
|
||||
"""Returns a str representing the "query" of an example retriever call."""
|
||||
...
|
||||
"""Returns a str representing the ``query`` of an example retriever call."""
|
||||
|
||||
@property
|
||||
def num_results_arg_name(self) -> str:
|
||||
"""Returns the name of the parameter for the number of results returned.
|
||||
|
||||
Usually something like ``k`` or ``top_k``."""
|
||||
return "k"
|
||||
|
||||
@pytest.fixture
|
||||
def retriever(self) -> BaseRetriever:
|
||||
@ -33,14 +39,34 @@ class RetrieversIntegrationTests(BaseStandardTests):
|
||||
return self.retriever_constructor(**self.retriever_constructor_params)
|
||||
|
||||
def test_k_constructor_param(self) -> None:
|
||||
"""Test that the retriever constructor accepts a k parameter, representing
|
||||
"""Test the number of results constructor parameter.
|
||||
|
||||
Test that the retriever constructor accepts a parameter representing
|
||||
the number of documents to return.
|
||||
|
||||
By default, the parameter tested is named ``k``, but it can be overridden by
|
||||
setting the ``num_results_arg_name`` property.
|
||||
|
||||
.. note::
|
||||
If the retriever doesn't support configuring the number of results returned
|
||||
via the constructor, this test can be skipped using a pytest ``xfail`` on
|
||||
the test class:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="This retriever doesn't support setting "
|
||||
"the number of results via the constructor."
|
||||
)
|
||||
def test_k_constructor_param(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
.. dropdown:: Troubleshooting
|
||||
|
||||
If this test fails, either the retriever constructor does not accept a k
|
||||
parameter, or the retriever does not return the correct number of documents
|
||||
(`k`) when it is set.
|
||||
If this test fails, the retriever constructor does not accept a number
|
||||
of results parameter, or the retriever does not return the correct number
|
||||
of documents ( of the one set in ``num_results_arg_name``) when it is
|
||||
set.
|
||||
|
||||
For example, a retriever like
|
||||
|
||||
@ -52,29 +78,51 @@ class RetrieversIntegrationTests(BaseStandardTests):
|
||||
|
||||
"""
|
||||
params = {
|
||||
k: v for k, v in self.retriever_constructor_params.items() if k != "k"
|
||||
k: v
|
||||
for k, v in self.retriever_constructor_params.items()
|
||||
if k != self.num_results_arg_name
|
||||
}
|
||||
params_3 = {**params, "k": 3}
|
||||
params_3 = {**params, self.num_results_arg_name: 3}
|
||||
retriever_3 = self.retriever_constructor(**params_3)
|
||||
result_3 = retriever_3.invoke(self.retriever_query_example)
|
||||
assert len(result_3) == 3
|
||||
assert all(isinstance(doc, Document) for doc in result_3)
|
||||
|
||||
params_1 = {**params, "k": 1}
|
||||
params_1 = {**params, self.num_results_arg_name: 1}
|
||||
retriever_1 = self.retriever_constructor(**params_1)
|
||||
result_1 = retriever_1.invoke(self.retriever_query_example)
|
||||
assert len(result_1) == 1
|
||||
assert all(isinstance(doc, Document) for doc in result_1)
|
||||
|
||||
def test_invoke_with_k_kwarg(self, retriever: BaseRetriever) -> None:
|
||||
"""Test that the invoke method accepts a k parameter, representing the number of
|
||||
documents to return.
|
||||
"""Test the number of results parameter in ``invoke()``.
|
||||
|
||||
Test that the invoke method accepts a parameter representing
|
||||
the number of documents to return.
|
||||
|
||||
By default, the parameter is named ``, but it can be overridden by
|
||||
setting the ``num_results_arg_name`` property.
|
||||
|
||||
.. note::
|
||||
If the retriever doesn't support configuring the number of results returned
|
||||
via the invoke method, this test can be skipped using a pytest ``xfail`` on
|
||||
the test class:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="This retriever doesn't support setting "
|
||||
"the number of results in the invoke method."
|
||||
)
|
||||
def test_invoke_with_k_kwarg(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
.. dropdown:: Troubleshooting
|
||||
|
||||
If this test fails, the retriever's invoke method does not accept a k
|
||||
parameter, or the retriever does not return the correct number of documents
|
||||
(`k`) when it is set.
|
||||
If this test fails, the retriever's invoke method does not accept a number
|
||||
of results parameter, or the retriever does not return the correct number
|
||||
of documents (``k`` of the one set in ``num_results_arg_name``) when it is
|
||||
set.
|
||||
|
||||
For example, a retriever like
|
||||
|
||||
@ -85,11 +133,15 @@ class RetrieversIntegrationTests(BaseStandardTests):
|
||||
should return 3 documents when invoked with a query.
|
||||
|
||||
"""
|
||||
result_1 = retriever.invoke(self.retriever_query_example, k=1)
|
||||
result_1 = retriever.invoke(
|
||||
self.retriever_query_example, None, **{self.num_results_arg_name: 1}
|
||||
)
|
||||
assert len(result_1) == 1
|
||||
assert all(isinstance(doc, Document) for doc in result_1)
|
||||
|
||||
result_3 = retriever.invoke(self.retriever_query_example, k=3)
|
||||
result_3 = retriever.invoke(
|
||||
self.retriever_query_example, None, **{self.num_results_arg_name: 3}
|
||||
)
|
||||
assert len(result_3) == 3
|
||||
assert all(isinstance(doc, Document) for doc in result_3)
|
||||
|
||||
@ -100,8 +152,8 @@ class RetrieversIntegrationTests(BaseStandardTests):
|
||||
.. dropdown:: Troubleshooting
|
||||
|
||||
If this test fails, the retriever's invoke method does not return a list of
|
||||
`langchain_core.document.Document` objects. Please confirm that your
|
||||
`_get_relevant_documents` method returns a list of `Document` objects.
|
||||
``langchain_core.document.Document`` objects. Please confirm that your
|
||||
``_get_relevant_documents`` method returns a list of ``Document`` objects.
|
||||
"""
|
||||
result = retriever.invoke(self.retriever_query_example)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user