This commit is contained in:
Chester Curme 2025-03-25 11:16:42 -04:00
parent e00ddec3a6
commit 84c8f907ff

View File

@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Callable, List, Optional
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
@ -15,9 +15,14 @@ class MockRetriever(BaseRetriever):
query: str,
*,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
filter: Optional[Callable[[Document], bool]] = None,
) -> List[Document]:
"""Return the documents"""
return self.docs
if filter is not None:
retrieved_docs = [doc for doc in self.docs if filter(doc)]
else:
retrieved_docs = self.docs
return retrieved_docs
def test_invoke() -> None:
@ -86,3 +91,11 @@ def test_invoke() -> None:
# Additionally, the document with page_content "b" will be ranked 1st.
assert len(ranked_documents) == 3
assert ranked_documents[0].page_content == "b"
# Test kwargs are propagated
def filter_function(document: Document) -> bool:
return document.metadata["id"] == 1
ranked_documents = ensemble_retriever.invoke("_", filter=filter_function)
assert len(ranked_documents) == 1
assert ranked_documents[0].metadata["id"] == 1