mirror of
https://github.com/hwchase17/langchain.git
synced 2025-04-30 04:45:23 +00:00
add test
This commit is contained in:
parent
e00ddec3a6
commit
84c8f907ff
@ -1,4 +1,4 @@
|
|||||||
from typing import List, Optional
|
from typing import Callable, List, Optional
|
||||||
|
|
||||||
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
|
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
@ -15,9 +15,14 @@ class MockRetriever(BaseRetriever):
|
|||||||
query: str,
|
query: str,
|
||||||
*,
|
*,
|
||||||
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
|
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
|
||||||
|
filter: Optional[Callable[[Document], bool]] = None,
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
"""Return the documents"""
|
"""Return the documents"""
|
||||||
return self.docs
|
if filter is not None:
|
||||||
|
retrieved_docs = [doc for doc in self.docs if filter(doc)]
|
||||||
|
else:
|
||||||
|
retrieved_docs = self.docs
|
||||||
|
return retrieved_docs
|
||||||
|
|
||||||
|
|
||||||
def test_invoke() -> None:
|
def test_invoke() -> None:
|
||||||
@ -86,3 +91,11 @@ def test_invoke() -> None:
|
|||||||
# Additionally, the document with page_content "b" will be ranked 1st.
|
# Additionally, the document with page_content "b" will be ranked 1st.
|
||||||
assert len(ranked_documents) == 3
|
assert len(ranked_documents) == 3
|
||||||
assert ranked_documents[0].page_content == "b"
|
assert ranked_documents[0].page_content == "b"
|
||||||
|
|
||||||
|
# Test kwargs are propagated
|
||||||
|
def filter_function(document: Document) -> bool:
|
||||||
|
return document.metadata["id"] == 1
|
||||||
|
|
||||||
|
ranked_documents = ensemble_retriever.invoke("_", filter=filter_function)
|
||||||
|
assert len(ranked_documents) == 1
|
||||||
|
assert ranked_documents[0].metadata["id"] == 1
|
||||||
|
Loading…
Reference in New Issue
Block a user