From 84c8f907fffbd749c6744f005d0458014bf8a17a Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Tue, 25 Mar 2025 11:16:42 -0400 Subject: [PATCH] add test --- .../unit_tests/retrievers/test_ensemble.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/libs/langchain/tests/unit_tests/retrievers/test_ensemble.py b/libs/langchain/tests/unit_tests/retrievers/test_ensemble.py index 4c5e9837c0b..f8f61d7e179 100644 --- a/libs/langchain/tests/unit_tests/retrievers/test_ensemble.py +++ b/libs/langchain/tests/unit_tests/retrievers/test_ensemble.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Callable, List, Optional from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun from langchain_core.documents import Document @@ -15,9 +15,14 @@ class MockRetriever(BaseRetriever): query: str, *, run_manager: Optional[CallbackManagerForRetrieverRun] = None, + filter: Optional[Callable[[Document], bool]] = None, ) -> List[Document]: """Return the documents""" - return self.docs + if filter is not None: + retrieved_docs = [doc for doc in self.docs if filter(doc)] + else: + retrieved_docs = self.docs + return retrieved_docs def test_invoke() -> None: @@ -86,3 +91,11 @@ def test_invoke() -> None: # Additionally, the document with page_content "b" will be ranked 1st. assert len(ranked_documents) == 3 assert ranked_documents[0].page_content == "b" + + # Test kwargs are propagated + def filter_function(document: Document) -> bool: + return document.metadata["id"] == 1 + + ranked_documents = ensemble_retriever.invoke("_", filter=filter_function) + assert len(ranked_documents) == 1 + assert ranked_documents[0].metadata["id"] == 1