Compare commits

...

2 Commits

Author SHA1 Message Date
Chester Curme
84c8f907ff add test 2025-03-25 11:16:42 -04:00
Chester Curme
e00ddec3a6 propagate kwargs 2025-03-25 11:15:32 -04:00
2 changed files with 27 additions and 6 deletions

View File

@@ -112,7 +112,9 @@ class EnsembleRetriever(BaseRetriever):
**kwargs,
)
try:
- result = self.rank_fusion(input, run_manager=run_manager, config=config)
+ result = self.rank_fusion(
+ input, run_manager=run_manager, config=config, **kwargs
+ )
except Exception as e:
run_manager.on_retriever_error(e)
raise e
@@ -146,7 +148,7 @@ class EnsembleRetriever(BaseRetriever):
)
try:
result = await self.arank_fusion(
- input, run_manager=run_manager, config=config
+ input, run_manager=run_manager, config=config, **kwargs
)
except Exception as e:
await run_manager.on_retriever_error(e)
@@ -163,6 +165,7 @@ class EnsembleRetriever(BaseRetriever):
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
+ **kwargs: Any,
) -> List[Document]:
"""
Get the relevant documents for a given query.
@@ -175,7 +178,7 @@ class EnsembleRetriever(BaseRetriever):
"""
# Get fused result of the retrievers.
- fused_documents = self.rank_fusion(query, run_manager)
+ fused_documents = self.rank_fusion(query, run_manager, **kwargs)
return fused_documents
@@ -184,6 +187,7 @@ class EnsembleRetriever(BaseRetriever):
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
+ **kwargs: Any,
) -> List[Document]:
"""
Asynchronously get the relevant documents for a given query.
@@ -196,7 +200,7 @@ class EnsembleRetriever(BaseRetriever):
"""
# Get fused result of the retrievers.
- fused_documents = await self.arank_fusion(query, run_manager)
+ fused_documents = await self.arank_fusion(query, run_manager, **kwargs)
return fused_documents
@@ -206,6 +210,7 @@ class EnsembleRetriever(BaseRetriever):
run_manager: CallbackManagerForRetrieverRun,
*,
config: Optional[RunnableConfig] = None,
+ **kwargs: Any,
) -> List[Document]:
"""
Retrieve the results of the retrievers and use rank_fusion_func to get
@@ -225,6 +230,7 @@ class EnsembleRetriever(BaseRetriever):
patch_config(
config, callbacks=run_manager.get_child(tag=f"retriever_{i + 1}")
),
+ **kwargs,
)
for i, retriever in enumerate(self.retrievers)
]
@@ -247,6 +253,7 @@ class EnsembleRetriever(BaseRetriever):
run_manager: AsyncCallbackManagerForRetrieverRun,
*,
config: Optional[RunnableConfig] = None,
+ **kwargs: Any,
) -> List[Document]:
"""
Asynchronously retrieve the results of the retrievers
@@ -268,6 +275,7 @@ class EnsembleRetriever(BaseRetriever):
config,
callbacks=run_manager.get_child(tag=f"retriever_{i + 1}"),
),
+ **kwargs,
)
for i, retriever in enumerate(self.retrievers)
]

View File

@@ -1,4 +1,4 @@
- from typing import List, Optional
+ from typing import Callable, List, Optional
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
@@ -15,9 +15,14 @@ class MockRetriever(BaseRetriever):
query: str,
*,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
+ filter: Optional[Callable[[Document], bool]] = None,
) -> List[Document]:
"""Return the documents"""
- return self.docs
+ if filter is not None:
+ retrieved_docs = [doc for doc in self.docs if filter(doc)]
+ else:
+ retrieved_docs = self.docs
+ return retrieved_docs
def test_invoke() -> None:
@@ -86,3 +91,11 @@ def test_invoke() -> None:
# Additionally, the document with page_content "b" will be ranked 1st.
assert len(ranked_documents) == 3
assert ranked_documents[0].page_content == "b"
+ # Test kwargs are propagated
+ def filter_function(document: Document) -> bool:
+ return document.metadata["id"] == 1
+ ranked_documents = ensemble_retriever.invoke("_", filter=filter_function)
+ assert len(ranked_documents) == 1
+ assert ranked_documents[0].metadata["id"] == 1