mirror of
https://github.com/hwchase17/langchain.git
synced 2025-04-27 11:41:51 +00:00
add test
This commit is contained in:
parent
e00ddec3a6
commit
84c8f907ff
@ -1,4 +1,4 @@
|
||||
from typing import List, Optional
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
@ -15,9 +15,14 @@ class MockRetriever(BaseRetriever):
|
||||
query: str,
|
||||
*,
|
||||
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
|
||||
filter: Optional[Callable[[Document], bool]] = None,
|
||||
) -> List[Document]:
|
||||
"""Return the documents"""
|
||||
return self.docs
|
||||
if filter is not None:
|
||||
retrieved_docs = [doc for doc in self.docs if filter(doc)]
|
||||
else:
|
||||
retrieved_docs = self.docs
|
||||
return retrieved_docs
|
||||
|
||||
|
||||
def test_invoke() -> None:
|
||||
@ -86,3 +91,11 @@ def test_invoke() -> None:
|
||||
# Additionally, the document with page_content "b" will be ranked 1st.
|
||||
assert len(ranked_documents) == 3
|
||||
assert ranked_documents[0].page_content == "b"
|
||||
|
||||
# Test kwargs are propagated
|
||||
def filter_function(document: Document) -> bool:
|
||||
return document.metadata["id"] == 1
|
||||
|
||||
ranked_documents = ensemble_retriever.invoke("_", filter=filter_function)
|
||||
assert len(ranked_documents) == 1
|
||||
assert ranked_documents[0].metadata["id"] == 1
|
||||
|
Loading…
Reference in New Issue
Block a user